[BE] Enable flake8-comprehension rule C417 (#97880)
Enables the flake8-comprehensions rule C417. Ruff autogenerated these fixes to the codebase.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97880
Approved by: https://github.com/ezyang, https://github.com/kit1980, https://github.com/albanD
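For context, C417 is the flake8-comprehensions check that flags map() called with a lambda and asks for the equivalent comprehension or generator expression. A minimal sketch of the rewrite pattern applied throughout this commit (names and values below are illustrative, not taken from the diff):

    # flagged by flake8-comprehensions C417
    squares = list(map(lambda n: n * n, range(10)))
    any_even = any(map(lambda n: n % 2 == 0, range(10)))

    # autofixed form: comprehension / generator expression
    squares = [n * n for n in range(10)]
    any_even = any((n % 2 == 0 for n in range(10)))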
This commit is contained in: parent 1d08b5b103, commit 47dca20d80
.flake8 (2 changed lines)
@@ -14,7 +14,7 @@ ignore =
     # these ignores are from flake8-bugbear; please fix!
     B007,B008,B017,B019,B020,B023,B024,B026,B027,B028,B903,B904,B905,B906,B907
     # these ignores are from flake8-comprehensions; please fix!
-    C407,C417
+    C407
     # these ignores are from flake8-logging-format; please fix!
     G001,G002,G003,G004,G100,G101,G200,G201,G202
 per-file-ignores =
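A detail worth noting in the code changes that follow: a bare map(...) becomes a parenthesized generator expression, while list(map(...)) becomes a list comprehension. The two are not interchangeable, since a generator can only be consumed once. A small illustrative snippet (values are made up, not from the repository):

    vals = (int(el) for el in "1, 2, 3".split(', '))  # generator: single pass
    assert list(vals) == [1, 2, 3]
    assert list(vals) == []  # already exhausted

    vals = [int(el) for el in "1, 2, 3".split(', ')]  # list: can be reused
    assert vals[0] == 1 and list(vals) == [1, 2, 3]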
.github/scripts/run_torchbench.py (vendored, 12 changed lines)
@@ -119,7 +119,7 @@ def get_valid_userbenchmarks(torchbench_path: str) -> List[str]:
             [os.path.join(ub_path, ubdir) for ubdir in os.listdir(ub_path)],
         )
     )
-    valid_ubs = list(map(lambda x: os.path.basename(x), ubs))
+    valid_ubs = [os.path.basename(x) for x in ubs]
     return valid_ubs


@@ -130,13 +130,13 @@ def extract_models_from_pr(
     userbenchmark_list = []
     pr_list = []
     with open(prbody_file, "r") as pf:
-        lines = map(lambda x: x.strip(), pf.read().splitlines())
+        lines = (x.strip() for x in pf.read().splitlines())
         magic_lines = list(filter(lambda x: x.startswith(MAGIC_PREFIX), lines))
         if magic_lines:
             # Only the first magic line will be recognized.
-            pr_list = list(
-                map(lambda x: x.strip(), magic_lines[0][len(MAGIC_PREFIX) :].split(","))
-            )
+            pr_list = [
+                x.strip() for x in magic_lines[0][len(MAGIC_PREFIX) :].split(",")
+            ]
     valid_models = get_valid_models(torchbench_path)
     valid_ubs = get_valid_userbenchmarks(torchbench_path)
     for pr_bm in pr_list:

@@ -158,7 +158,7 @@ def extract_models_from_pr(
 def find_torchbench_branch(prbody_file: str) -> str:
     branch_name: str = ""
     with open(prbody_file, "r") as pf:
-        lines = map(lambda x: x.strip(), pf.read().splitlines())
+        lines = (x.strip() for x in pf.read().splitlines())
         magic_lines = list(
             filter(lambda x: x.startswith(MAGIC_TORCHBENCH_PREFIX), lines)
         )
@@ -423,7 +423,7 @@ def tensor_is_on_xla(tensors):
     if not isinstance(tensors, (tuple, list)):
         tensors = [tensors]
     tensors = [x for x in tensors if isinstance(x, torch.Tensor)]
-    return any(map(lambda x: x.device.type == "xla", tensors))
+    return any((x.device.type == "xla" for x in tensors))


 def timed(

@@ -757,12 +757,9 @@ def speedup_experiment_ds(args, model_iter_fn, model, example_inputs):
     shapes = [x[0].shape for x in example_inputs]
     shape_keys = sorted(set(shapes))
     shape_speedups = {
-        shape: list(
-            map(
-                lambda it: it[1],
-                filter(lambda it: it[0] == shape, zip(shapes, speedups)),
-            )
-        )
+        shape: [
+            it[1] for it in filter(lambda it: it[0] == shape, zip(shapes, speedups))
+        ]
         for shape in shape_keys
     }
     output_str = (
@@ -351,7 +351,7 @@ def get_skip_tests(suite):
         if hasattr(module, "SKIP_TRAIN"):
             skip_tests.update(module.SKIP_TRAIN)

-    skip_tests = map(lambda name: f"-x {name}", skip_tests)
+    skip_tests = (f"-x {name}" for name in skip_tests)
     skip_str = " ".join(skip_tests)
     return skip_str

@@ -23,7 +23,7 @@ def sparse_grad_output(a, b):
 def read_matrix_params(path):
     with open(path, 'r') as file:
         line = file.readline()
-        nrows, ncols, nnz = map(lambda el: int(el), line.split(', '))
+        nrows, ncols, nnz = (int(el) for el in line.split(', '))
         return (nrows, ncols), nnz
@@ -39,9 +39,9 @@ def csr_to_coo(indices, indptr, shape):

 def load_sparse_matrix(path, device):
     with open(path, 'r') as file:
-        nrows, ncols, nnz = map(lambda el: int(el), file.readline().split(', '))
-        index_pointers = map(lambda el: int(el), file.readline().split())
-        indices = map(lambda el: int(el), file.readline().split())
+        nrows, ncols, nnz = (int(el) for el in file.readline().split(', '))
+        index_pointers = (int(el) for el in file.readline().split())
+        indices = (int(el) for el in file.readline().split())

         index_pointers = list(index_pointers)
         indices = list(indices)

@@ -52,17 +52,17 @@ def load_sparse_matrix(path, device):

 def gen_vector(path, device):
     with open(path, 'r') as file:
-        nrows, ncols, nnz = map(lambda el: int(el), file.readline().split(', '))
-        index_pointers = map(lambda el: int(el), file.readline().split())
-        indices = map(lambda el: int(el), file.readline().split())
+        nrows, ncols, nnz = (int(el) for el in file.readline().split(', '))
+        index_pointers = (int(el) for el in file.readline().split())
+        indices = (int(el) for el in file.readline().split())
         return torch.randn(nrows, dtype=torch.double, device=device)


 def gen_matrix(path, device):
     with open(path, 'r') as file:
-        nrows, ncols, nnz = map(lambda el: int(el), file.readline().split(', '))
-        index_pointers = map(lambda el: int(el), file.readline().split())
-        indices = map(lambda el: int(el), file.readline().split())
+        nrows, ncols, nnz = (int(el) for el in file.readline().split(', '))
+        index_pointers = (int(el) for el in file.readline().split())
+        indices = (int(el) for el in file.readline().split())
         return torch.randn(nrows, ncols, dtype=torch.double, device=device)
@@ -499,9 +499,9 @@ class PartialMaml(torch.nn.Module):
         logits = net(x_spt)
         loss = F.cross_entropy(logits, y_spt)
         grad = torch.autograd.grad(loss, net.parameters())
-        fast_weights = list(
-            map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters()))
-        )
+        fast_weights = [
+            p[1] - self.update_lr * p[0] for p in zip(grad, net.parameters())
+        ]

         # this is the loss and accuracy before first update
         with torch.no_grad():

@@ -3427,7 +3427,7 @@ class TestVmapOperatorsOpInfo(TestCase):
         check_shape_only = op.name in ('empty_like', 'new_empty')
         for sample_input in sample_inputs_itr:
             args = (sample_input.input,) + sample_input.args
-            if not any(map(lambda arg: isinstance(arg, torch.Tensor), args)):
+            if not any((isinstance(arg, torch.Tensor) for arg in args)):
                 # Atleast one tensor required for vmap.
                 continue
             kwargs = sample_input.kwargs
@@ -1346,7 +1346,7 @@ class TestFxDetectInputWeightEqualization(QuantizationTestCase):
         # assert that each of the desired modules have the observers inserted
         for fqn, module in prepared_for_callibrate_model.named_modules():
             # check if module is a supported module
-            is_in_include_list = sum(list(map(lambda x: isinstance(module, x), mods_to_check))) > 0
+            is_in_include_list = sum([isinstance(module, x) for x in mods_to_check]) > 0

             if is_in_include_list:
                 # make sure it has the observer attribute

@@ -2398,10 +2398,8 @@ class TestAutograd(TestCase):
             y = torch.randn((3, 3), requires_grad=True)
             MyFunction.apply(x, y).sum().backward()

-            has_deprecated = map(lambda warn:
-                                 'deprecated' in str(warn) and
-                                 'saved_variables' in str(warn),
-                                 warns)
+            has_deprecated = ('deprecated' in str(warn) and
+                              'saved_variables' in str(warn) for warn in warns)
             has_deprecated = reduce(lambda x, y: x or y, has_deprecated)
             self.assertTrue(has_deprecated)
@@ -498,7 +498,7 @@ class TestBinaryUfuncs(TestCase):
         )

         def _supported(dtypes):
-            return all(map(lambda x: x in supported_dtypes, dtypes))
+            return all((x in supported_dtypes for x in dtypes))

         # int x int type promotion
         if _supported((torch.int16, torch.int32, torch.int64)):

@@ -6793,7 +6793,7 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
         for fn in [torch.det, torch.logdet, torch.slogdet, torch.linalg.slogdet]:
             expected_value = []
             actual_value = fn(full_tensor)
-            for full_idx in itertools.product(*map(lambda x: list(range(x)), batchdims)):
+            for full_idx in itertools.product(*(list(range(x)) for x in batchdims)):
                 expected_value.append(fn(full_tensor[full_idx]))

             if fn == torch.slogdet or fn == torch.linalg.slogdet:
@@ -148,7 +148,7 @@ class TestCommon(TestCase):
         if isinstance(result, torch.Tensor):
             self.assertTrue(result.device == cuda_device)
         elif is_iterable_of_tensors(result):
-            self.assertTrue(all(map(lambda t: t.device == cuda_device, result)))
+            self.assertTrue(all((t.device == cuda_device for t in result)))
         else:
             self.skipTest(
                 "Skipped! Only supports single tensor or iterable of tensor outputs."

@@ -711,7 +711,7 @@ class TestCommon(TestCase):
                 return (out.stride(),)

             # assumes (see above) that out is an iterable of tensors
-            return tuple(map(lambda t: t.stride(), out))
+            return tuple((t.stride() for t in out))

         # Extracts data pointers from a tensor or iterable of tensors into a tuple
         # NOTE: only extracts on the CPU and CUDA device types since some

@@ -724,7 +724,7 @@ class TestCommon(TestCase):
                 return (out.data_ptr(),)

             # assumes (see above) that out is an iterable of tensors
-            return tuple(map(lambda t: t.data_ptr(), out))
+            return tuple((t.data_ptr() for t in out))

         @suppress_warnings
         def _compare_out(transform, *, compare_strides_and_data_ptrs=True):

@@ -831,7 +831,7 @@ class TestCommon(TestCase):
                 return (out.stride(),)

             # assumes (see above) that out is an iterable of tensors
-            return tuple(map(lambda t: t.stride(), out))
+            return tuple((t.stride() for t in out))

         # Extracts data pointers from a tensor or iterable of tensors into a tuple
         # NOTE: only extracts on the CPU and CUDA device types since some

@@ -844,7 +844,7 @@ class TestCommon(TestCase):
                 return (out.data_ptr(),)

             # assumes (see above) that out is an iterable of tensors
-            return tuple(map(lambda t: t.data_ptr(), out))
+            return tuple((t.data_ptr() for t in out))

         def _compare_out(transform, *, compare_strides_and_data_ptrs=True):
             out = _apply_out_transform(transform, expected)
@@ -1584,7 +1584,7 @@ class TestSparseCSR(TestCase):
             return t.cpu().resolve_conj().numpy()

         res = _npref_block_addmm_addmv(
-            *map(lambda t: prep_input(t), (c, a, b)),
+            *(prep_input(t) for t in (c, a, b)),
             alpha,
             beta
         )

@@ -2406,7 +2406,7 @@ class TestSparseCSR(TestCase):
             output.backward(covector)

             # Compute dense result and compare with sparse result
-            c1, a1, b1 = map(lambda x: x.detach().to_dense().requires_grad_(True), [c, a, b])
+            c1, a1, b1 = (x.detach().to_dense().requires_grad_(True) for x in [c, a, b])
             dense_output = sample.kwargs['alpha'] * (a1 @ b1) * torch.ones_like(c).to_dense() + sample.kwargs['beta'] * c1
             self.assertEqual(output, dense_output)
             dense_covector = covector.to_dense()
@@ -1138,8 +1138,7 @@ class TestTypePromotion(TestCase):
                 exp_type = expected_type(inp, min_v, max_v)
                 if exp_type != torch.bool:
                     actual = torch.clamp(inp, min_v, max_v)
-                    inps = list(map(lambda x: x.to(exp_type) if isinstance(x, torch.Tensor) else x,
-                                    (inp, min_v, max_v)))
+                    inps = [x.to(exp_type) if isinstance(x, torch.Tensor) else x for x in (inp, min_v, max_v)]
                     expected = torch.clamp(inps[0], inps[1], inps[2])
                     self.assertEqual(actual, expected)
                     if inp.dtype in floating_types() or exp_type == inp.dtype:

@@ -1151,8 +1150,7 @@ class TestTypePromotion(TestCase):
                 exp_type = expected_type(inp, val)
                 if exp_type != torch.bool:
                     actual = torch.clamp_min(inp, val)
-                    inps = list(map(lambda x: x.to(exp_type) if isinstance(x, torch.Tensor) else x,
-                                    (inp, val)))
+                    inps = [x.to(exp_type) if isinstance(x, torch.Tensor) else x for x in (inp, val)]
                     expected = torch.clamp_min(inps[0], inps[1])
                     self.assertEqual(actual.dtype, exp_type)
                     self.assertEqual(actual, expected)
@@ -184,7 +184,7 @@ def _parse_reveals(file: IO[str]) -> List[str]:
     string = file.read().replace("*", "")

     # Grab all `# E:`-based comments
-    comments_array = list(map(lambda str: str.partition(" # E: ")[2], string.split("\n")))
+    comments_array = [str.partition(" # E: ")[2] for str in string.split("\n")]
     comments = "/n".join(comments_array)

     # Only search for the `{*}` pattern within comments,

@@ -443,8 +443,8 @@ def gen_autograd_functions_lib(
     # get a 1D list of diffinfos, we do not need them to be per FunctionSchema/DispatchKey here
     # infos with the diff dispatchkeys but the same name will still be in the same shard.
     infos = get_infos_with_derivatives_list(differentiability_infos)
-    declarations = list(map(lambda f: process_function(f, FUNCTION_DECLARATION), infos))
-    definitions = list(map(lambda f: process_function(f, FUNCTION_DEFINITION), infos))
+    declarations = [process_function(f, FUNCTION_DECLARATION) for f in infos]
+    definitions = [process_function(f, FUNCTION_DEFINITION) for f in infos]

     file_basename = "Functions"
     fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
@@ -159,9 +159,9 @@ _SKIP_PYTHON_BINDINGS = [
     "_nested_view_from_buffer_copy_out",
 ]

-SKIP_PYTHON_BINDINGS = list(
-    map(lambda pattern: re.compile(rf"^{pattern}$"), _SKIP_PYTHON_BINDINGS)
-)
+SKIP_PYTHON_BINDINGS = [
+    re.compile(rf"^{pattern}$") for pattern in _SKIP_PYTHON_BINDINGS
+]

 # These function signatures are not exposed to Python. Note that this signature
 # list does not support regex.

@@ -864,7 +864,7 @@ def method_impl(
         name=name,
         pycname=pycname,
         method_header=method_header,
-        max_args=max(map(lambda o: o.signature.arguments_count(), overloads)),
+        max_args=max((o.signature.arguments_count() for o in overloads)),
         signatures=signatures,
         traceable=traceable,
         check_has_torch_function=gen_has_torch_function_check(
@@ -1216,7 +1216,7 @@ def sort_overloads(
             del larger_than[j]
             sorted_ids.append(j)

-    return list(map(lambda x: grouped_overloads[x], sorted_ids))
+    return [grouped_overloads[x] for x in sorted_ids]


 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

@@ -1250,9 +1250,9 @@ def emit_single_dispatch(
     # dispatch lambda signature
     name = cpp.name(f.func)
     lambda_formals = ", ".join(
-        map(
-            lambda a: f"{a.type_str} {a.name}",
-            dispatch_lambda_args(ps, f, symint=symint),
+        (
+            f"{a.type_str} {a.name}"
+            for a in dispatch_lambda_args(ps, f, symint=symint)
         )
     )
     lambda_return = dispatch_lambda_return_str(f)
@@ -149,12 +149,7 @@ def main(argv: List[Any]) -> None:
             model_dict = yaml.safe_load(model_file)
             model_dicts.append(model_dict)

-    selective_builders = list(
-        map(
-            lambda m: SelectiveBuilder.from_yaml_dict(m),
-            model_dicts,
-        )
-    )
+    selective_builders = [SelectiveBuilder.from_yaml_dict(m) for m in model_dicts]

     # While we have the model_dicts generate the supported mobile models api
     gen_supported_mobile_models(model_dicts, options.output_dir)

@@ -67,9 +67,7 @@ def get_selected_kernel_dtypes_code(
 ):
     body_parts = []
     for kernel_tag, dtypes in selective_builder.kernel_metadata.items():
-        conditions = list(
-            map(lambda x: "scalar_type == at::ScalarType::" + x, dtypes)
-        )
+        conditions = ["scalar_type == at::ScalarType::" + x for x in dtypes]
         body_parts.append(
             if_condition_template.substitute(
                 kernel_tag_name=kernel_tag,
@@ -316,11 +316,11 @@ class VariableBuilder:
         elif istype(
             value, (dict, collections.defaultdict, collections.OrderedDict)
         ) and all(
-            map(
-                lambda k: ConstantVariable.is_literal(k)
+            (
+                ConstantVariable.is_literal(k)
                 or self.tensor_can_be_dict_key(k)
-                or isinstance(k, enum.Enum),
-                value.keys(),
+                or isinstance(k, enum.Enum)
+                for k in value.keys()
             )
         ):
             if not value and self.get_source().is_nn_module():
@@ -2348,7 +2348,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
                     [isinstance(x, torch.Tensor) for x in tensors_saved_for_backwards]
                 )
                 # See Note [Detaching saved tensors in AOTAutograd]
-                ctx.save_for_backward(*map(lambda x: x.detach() if x._is_view() else x, tensors_saved_for_backwards))
+                ctx.save_for_backward(*(x.detach() if x._is_view() else x for x in tensors_saved_for_backwards))
                 symint_outs = fw_outs[-num_symints_saved_for_bw:]
                 assert all(
                     [

@@ -2360,7 +2360,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
             else:
                 tensors_saved_for_backwards = fw_outs[num_forward_returns:]
                 # See Note [Detaching saved tensors in AOTAutograd]
-                ctx.save_for_backward(*map(lambda x: x.detach() if x._is_view() else x, tensors_saved_for_backwards))
+                ctx.save_for_backward(*(x.detach() if x._is_view() else x for x in tensors_saved_for_backwards))
                 ctx.symints = []

             raw_returns = fw_outs[0:num_forward_returns]
@@ -544,7 +544,7 @@ def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False
         # Iterate and concat the jacobians of different
         # inputs.
         for idx in range(len(flat_primals)):
-            r = tuple(map(lambda r_: r_[idx], chunked_results))
+            r = tuple((r_[idx] for r_ in chunked_results))
             flat_results.append(torch.cat(r, 0))

         return flat_results

@@ -567,7 +567,7 @@ def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False
         for t in flat_basis_chunk:
             assert t.size(0) == 1

-        flat_basis_chunk = list(map(lambda t: torch.squeeze(t, 0), flat_basis_chunk))
+        flat_basis_chunk = [torch.squeeze(t, 0) for t in flat_basis_chunk]

         basis = tree_unflatten(flat_basis_chunk, output_spec)
@@ -397,7 +397,7 @@ class SizeVarAllocator:

     def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]:
         strides = tuple(
-            map(lambda x: abs(x), self.stride_hints(index, vars))
+            map(abs, self.stride_hints(index, vars))
         ) # lambda to placate mypy
         order = list(range(len(strides)))
         order.sort(key=lambda x: (strides[x] == 0, strides[x]))

@@ -346,7 +346,7 @@ def _broadcast_shapes(*_shapes):
 def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
     # Computes common shape
     common_shape = _broadcast_shapes(
-        *map(lambda t: t.shape if isinstance(t, TensorLike) else None, args)
+        *(t.shape if isinstance(t, TensorLike) else None for t in args)
     )

     def __maybe_broadcast(x, shape):
@@ -292,7 +292,7 @@ class PerChannelDetector(DetectorBase):
         # get the fully qualified name and check if in list of modules to include and list of modules to ignore
         for fqn, module in model.named_modules():

-            is_in_include_list = sum(list(map(lambda x: isinstance(module, x), self.supported_modules))) > 0
+            is_in_include_list = sum([isinstance(module, x) for x in self.supported_modules]) > 0

             # check if the module per_channel is supported
             # based on backend

@@ -517,10 +517,10 @@ class DynamicStaticDetector(DetectorBase):
         Returns True if the module is supported by observer, False otherwise
         """
         # check to see if module is of a supported type
-        is_supported_type = sum(list(map(lambda x: isinstance(module, x), self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED))) > 0
+        is_supported_type = sum([isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED]) > 0

         # check if it will be supported
-        future_supported_type = sum(list(map(lambda x: isinstance(module, x), self.DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED))) > 0
+        future_supported_type = sum([isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED]) > 0

         # supported
         supported = is_supported_type or future_supported_type

@@ -578,7 +578,7 @@ class DynamicStaticDetector(DetectorBase):
             post_obs_dist_classif = self.STATIONARY_STR if post_stat > self.tolerance else self.NON_STATIONARY_STR

             # check if current support or future support
-            is_supported_type = sum(list(map(lambda x: isinstance(module, x), self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED))) > 0
+            is_supported_type = sum([isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED]) > 0

             # store the set of important information for this module
             module_info = {
@@ -791,7 +791,7 @@ class InputWeightEqualizationDetector(DetectorBase):
         Returns True if the module is supported by observer, False otherwise
         """
         # check to see if module is of a supported type
-        is_supported_type = sum(list(map(lambda x: type(module) is x, self.SUPPORTED_MODULES))) > 0
+        is_supported_type = sum([type(module) is x for x in self.SUPPORTED_MODULES]) > 0

         # this is check for observer insertion
         if insert:

@@ -54,7 +54,7 @@ def _cast(value, dtype):
     elif isinstance(value, collections.abc.Mapping):
         return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
     elif isinstance(value, collections.abc.Iterable):
-        iterable = map(lambda v: _cast(v, dtype), value)
+        iterable = (_cast(v, dtype) for v in value)
         if isinstance(value, (list, tuple)):
             return type(value)(iterable)
         else:
@@ -259,7 +259,7 @@ class _OverlapInfo:
         assert (
             len(self.broadcast_handles) == self.num_bucket_assignments
         ), f"Missing at least one broadcast handle on rank {dist.get_rank()}"
-        _ = list(map(lambda x: x.wait(), self.broadcast_handles))
+        _ = [x.wait() for x in self.broadcast_handles]
         self.broadcast_handles.clear()

     def clear_per_iter_info(self) -> None:

@@ -807,7 +807,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
         handles = []
         for rank in range(self.world_size):
             handles.extend(self._broadcast_params_from_rank(rank))
-        _ = list(map(lambda x: x.wait(), handles))
+        _ = [x.wait() for x in handles]

     @property
     def _device_to_params_per_rank(
@@ -109,7 +109,7 @@ def _masked_tensor_str(data, mask, formatter):
         for d in data
     ]
     max_len = max(
-        map(lambda x: 8 if x[1] else len(x[0]), zip(formatted_elements, ~mask))
+        (8 if x[1] else len(x[0]) for x in zip(formatted_elements, ~mask))
     )
     return (
         "["

@@ -885,7 +885,7 @@ def arange(g: jit_utils.GraphContext, *args):
             dtype = symbolic_helper._maybe_get_const(dtype, "i")
         return dtype

-    if len(args) == 2 and all(map(lambda val: isinstance(val, int), args)):
+    if len(args) == 2 and all((isinstance(val, int) for val in args)):
         # aten::arange(Scalar start, Scalar end)
         dtype = torch.int64
         # Start index.
@@ -91,7 +91,7 @@ class _StorageBase:
         return str(self)

     def __iter__(self):
-        return iter(map(lambda i: self[i], range(self.size())))
+        return iter((self[i] for i in range(self.size())))

     def __copy__(self):
         return self.clone()

@@ -725,7 +725,7 @@ class TypedStorage:

     def __iter__(self):
         _warn_typed_storage_removal()
-        return iter(map(lambda i: self[i], range(self.size())))
+        return iter((self[i] for i in range(self.size())))

     def __copy__(self):
         _warn_typed_storage_removal()
@@ -6827,7 +6827,7 @@ def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs):

         if mask_t.sum() == 0:
             def random_index(shape):
-                return tuple(map(lambda max_idx: random.randrange(0, max_idx), shape))
+                return tuple((random.randrange(0, max_idx) for max_idx in shape))

             mask_t[random_index(mask_t.shape)] = True
             return mask_t

@@ -5214,7 +5214,7 @@ class ModuleTest(TestBase):
         type_map = {torch.double: torch.float}
         cpu_input_tuple = cpu_input if isinstance(cpu_input, tuple) else (cpu_input,)

-        is_any_input_complex = any(map(lambda t: isinstance(t, torch.Tensor) and t.dtype.is_complex, cpu_input_tuple))
+        is_any_input_complex = any((isinstance(t, torch.Tensor) and t.dtype.is_complex for t in cpu_input_tuple))

         gpu_input_tuple = to_gpu(cpu_input_tuple, type_map=type_map)
@@ -869,9 +869,9 @@ class OpInfo:

         # Attribute to verify dynamic_dtypes are used.
         self.dynamic_dtypes = any(
-            map(
-                lambda dtypes: isinstance(dtypes, utils._dynamic_dispatch_dtypes),
-                dtypes_args,
+            (
+                isinstance(dtypes, utils._dynamic_dispatch_dtypes)
+                for dtypes in dtypes_args
             )
         )

@@ -1661,7 +1661,7 @@ def generate_elementwise_binary_small_value_tensors(
         complex_vals = product(_float_vals, _float_vals)
         # Note the use of list is required here or the map generator will be
         # emptied by the following product and it won't produce the desired cross-product
-        complex_vals = list(map(lambda x: complex(*x), complex_vals))
+        complex_vals = [complex(*x) for x in complex_vals]
         prod = product(complex_vals, complex_vals)
     elif dtype in (torch.int8, torch.int16, torch.int32, torch.int64):
         prod = product(_int_vals, _int_vals)
@@ -1701,7 +1701,7 @@ def generate_elementwise_binary_large_value_tensors(
         complex_vals = product(_large_float_vals, _large_float_vals)
         # Note the use of list is required here or the map generator will be
         # emptied by the following product and it won't produce the desired cross-product
-        complex_vals = list(map(lambda x: complex(*x), complex_vals))
+        complex_vals = [complex(*x) for x in complex_vals]
         prod = product(complex_vals, complex_vals)
     elif dtype in (torch.int16, torch.int32, torch.int64):
         prod = product(_large_int_vals, _large_int_vals)

@@ -1732,7 +1732,7 @@ def generate_elementwise_binary_extremal_value_tensors(
         complex_vals = product(_float_extremals, _float_extremals)
         # Note the use of list is required here or the map generator will be
         # emptied by the following product and it won't produce the desired cross-product
-        complex_vals = list(map(lambda x: complex(*x), complex_vals))
+        complex_vals = [complex(*x) for x in complex_vals]
         prod = product(complex_vals, complex_vals)
     else:
         raise ValueError("Unsupported dtype!")
@@ -315,7 +315,7 @@ class PythonOutArgument(PythonArgument):
                 outputs=outputs,
             )
         elif size > 1:
-            if any(map(lambda a: not a.type.is_tensor_like(), outputs)):
+            if any((not a.type.is_tensor_like() for a in outputs)):
                 raise RuntimeError(f"Unsupported output type: {outputs}")
             return PythonOutArgument(
                 name="out",

@@ -390,9 +390,9 @@ class PythonSignature:
     # signature_str_pyi().
     def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
         args = self.arguments(skip_outputs=skip_outputs)
-        schema_formals: List[str] = list(
-            map(lambda a: a.argument_str(method=self.method, symint=symint), args)
-        )
+        schema_formals: List[str] = [
+            a.argument_str(method=self.method, symint=symint) for a in args
+        ]
         positional_argc = len(self.input_args)
         if len(schema_formals) > positional_argc:
             schema_formals.insert(positional_argc, "*")
@@ -401,9 +401,9 @@ class PythonSignature:

     def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
         args = self.arguments(skip_outputs=skip_outputs)
-        schema_formals: List[str] = list(
-            map(lambda a: a.argument_str_pyi(method=self.method), args)
-        )
+        schema_formals: List[str] = [
+            a.argument_str_pyi(method=self.method) for a in args
+        ]
         positional_argc = len(self.input_args)
         if len(schema_formals) > positional_argc:
             schema_formals.insert(positional_argc, "*")

@@ -418,9 +418,9 @@ class PythonSignature:
     def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
         # only pyi uses vararg signatures
         args = self.arguments(skip_outputs=skip_outputs)
-        schema_formals: List[str] = list(
-            map(lambda a: a.argument_str_pyi(method=self.method), args)
-        )
+        schema_formals: List[str] = [
+            a.argument_str_pyi(method=self.method) for a in args
+        ]
         # vararg only applies to pyi signatures. vararg variants are not generated for all signatures
         num_args = self.arguments_count()
         num_positionalargs = len(self.input_args)
@@ -478,9 +478,9 @@ class PythonSignatureDeprecated(PythonSignature):

     def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
         args = self.arguments(skip_outputs=skip_outputs)
-        schema_formals: List[str] = list(
-            map(lambda a: a.argument_str_pyi(method=self.method, deprecated=True), args)
-        )
+        schema_formals: List[str] = [
+            a.argument_str_pyi(method=self.method, deprecated=True) for a in args
+        ]
         positional_argc = len(self.input_args)
         if len(schema_formals) > positional_argc:
             schema_formals.insert(positional_argc, "*")

@@ -882,10 +882,10 @@ def signature_from_schema(


 def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
-    if len(returns) <= 1 or all(map(lambda r: r.name is None, returns)):
+    if len(returns) <= 1 or all((r.name is None for r in returns)):
         return []
     else:
-        if any(map(lambda r: r.name is None, returns)):
+        if any((r.name is None for r in returns)):
             # When building on Windows, `PyStructSequence_UnnamedField` could not be
             # resolved by the linker for some reason, which cause error in building:
             #
@@ -897,7 +897,7 @@ def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
             # or none of them.
             raise ValueError("Unnamed field is not supported by codegen")

-        return list(map(lambda r: str(r.name), returns))
+        return [str(r.name) for r in returns]


 def argument_type_str_pyi(t: Type) -> str:

@@ -1157,7 +1157,7 @@ def dispatch_lambda_return_str(f: NativeFunction) -> str:
     # mutable reference to temporary. Maybe we could assign it to a
     # variable itself.)
     returns_without_annotation = tuple(
-        map(lambda r: Return(r.name, r.type, None), f.func.returns)
+        (Return(r.name, r.type, None) for r in f.func.returns)
     )
     return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type()
     if return_str not in SUPPORTED_RETURN_TYPES:
@@ -1189,7 +1189,7 @@ def cpp_dispatch_exprs(
     exprs: Tuple[str, ...] = tuple()
     if not isinstance(python_signature, PythonSignatureDeprecated):
         # By default the exprs are consistent with the C++ signature.
-        exprs = tuple(map(lambda a: a.name, cpp_args))
+        exprs = tuple((a.name for a in cpp_args))
     else:
         # For deprecated python signature we may need fill in some constants.
         exprs = tuple(

@@ -1415,7 +1415,7 @@ def dispatch_lambda_exprs(
         lambda_args_exprs["self"] = "self"

     # 2. special packing/checking for TensorOptions.
-    tensor_options_args_names = list(map(lambda a: a.name, ps.tensor_options_args))
+    tensor_options_args_names = [a.name for a in ps.tensor_options_args]
     if has_toptions:
         if f.func.is_out_fn():
             raise RuntimeError(f"{f.func}: tensor options with output arg")
@@ -1429,7 +1429,7 @@ def dispatch_lambda_exprs(
                     f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
                 )
         if not all(
-            map(lambda a: a in tensor_options_args_names, TENSOR_OPTIONS_FIELDS.keys())
+            (a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS.keys())
         ):
             raise RuntimeError(
                 f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"

@@ -1457,9 +1457,7 @@ torch::utils::maybe_initialize_cuda(options);
             raise RuntimeError(
                 f"{f.func}: dtype in tensor_options_args without output arg"
             )
-        if not all(
-            map(lambda a: a in tensor_options_args_names, ("layout", "device"))
-        ):
+        if not all((a in tensor_options_args_names for a in ("layout", "device"))):
             raise RuntimeError(
                 f"{f.func}: incomplete tensor options for output check"
             )
@@ -1478,6 +1476,6 @@ check_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dt
     )

     return DispatchLambdaArgumentExprs(
-        exprs=tuple(map(lambda a: lambda_args_exprs[a.name], lambda_args)),
+        exprs=tuple((lambda_args_exprs[a.name] for a in lambda_args)),
         inits=inits,
     )

@@ -83,7 +83,7 @@ class SelectiveBuildOperator:
         if "debug_info" in op_info:
             di_list = op_info["debug_info"]
             assert isinstance(di_list, list)
-            debug_info = tuple(map(lambda x: str(x), di_list))
+            debug_info = tuple((str(x) for x in di_list))

         return SelectiveBuildOperator(
             name=op_name,
@@ -85,7 +85,7 @@ class SelectiveBuilder:
             di_list = data["debug_info"]
             assert isinstance(di_list, list)

-            debug_info = tuple(map(lambda x: str(x), di_list))
+            debug_info = tuple((str(x) for x in di_list))

         operators = {}
         operators_dict = data.get("operators", {})

@@ -99,7 +99,7 @@ class SelectiveBuilder:
         assert isinstance(kernel_metadata_dict, dict)

         for k, v in kernel_metadata_dict.items():
-            kernel_metadata[str(k)] = list(map(lambda dtype: str(dtype), v))
+            kernel_metadata[str(k)] = [str(dtype) for dtype in v]

         custom_classes = data.get("custom_classes", [])
         custom_classes = set(custom_classes) # type: ignore[arg-type]