Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[BE][CI] bump ruff to 0.8.4 (#143753)
Changes:
1. Bump `ruff` from 0.7.4 to 0.8.4.
2. Convert `%`-formatted strings to f-strings.
3. Convert arguments with the `__` prefix to positional-only arguments, using the `/` separator in function signatures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143753
Approved by: https://github.com/Skylion007
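For illustration, a minimal sketch of change 2 (the variable names here are made up, not taken from the diff):

    # %-formatting converted to an f-string with the equivalent format spec
    rank = 3
    old_style = "cuda:%d" % rank   # before
    new_style = f"cuda:{rank:d}"   # after
    assert old_style == new_style == "cuda:3"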
This commit is contained in: parent dbbc81cb34, commit b77406a9ec
@@ -619,9 +619,11 @@ def build_torchaudio(
     if host.using_docker():
         build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

-    host.run_cmd(f"cd audio && export FFMPEG_ROOT=$(pwd)/third_party/ffmpeg && export USE_FFMPEG=1 \
+    host.run_cmd(
+        f"cd audio && export FFMPEG_ROOT=$(pwd)/third_party/ffmpeg && export USE_FFMPEG=1 \
         && ./packaging/ffmpeg/build.sh \
-        && {build_vars} python3 setup.py bdist_wheel")
+        && {build_vars} python3 setup.py bdist_wheel"
+    )

     wheel_name = host.list_dir("audio/dist")[0]
     embed_libgomp(host, use_conda, os.path.join("audio", "dist", wheel_name))
@@ -109,8 +109,10 @@ def check_version(package: str) -> None:
                     {release_matrix[module['name']]} for channel {channel}. But its {module_version}"
                 )
             else:
-                print(f"{module['name']} version actual: {module_version} expected: \
-                    {release_matrix[module['name']]} for channel {channel}.")
+                print(
+                    f"{module['name']} version actual: {module_version} expected: \
+                    {release_matrix[module['name']]} for channel {channel}."
+                )

     else:
         print(f"Skip version check for channel {channel} as stable version is None")
@@ -1474,7 +1474,7 @@ init_command = [
    'black==23.12.1',
    'usort==1.0.8.post1',
    'isort==5.13.2',
-    'ruff==0.7.4',  # sync with RUFF
+    'ruff==0.8.4',  # sync with RUFF
 ]
 is_formatter = true

@@ -1559,7 +1559,7 @@ init_command = [
    'python3',
    'tools/linter/adapters/pip_init.py',
    '--dry-run={{DRYRUN}}',
-    'ruff==0.7.4',  # sync with PYFMT
+    'ruff==0.8.4',  # sync with PYFMT
 ]
 is_formatter = true
@@ -109,17 +109,17 @@ def sweep(benchmark):

     def print_header():
         local_print("\n")
-        local_print("%22s" % "")
-        for p in [50, 75, 90, 95]:
-            local_print("%14s%10s" % ("sec/iter", "ex/sec"))
+        local_print(" " * 22)
+        for _ in [50, 75, 90, 95]:
+            local_print(f"{'sec/iter':14s}{'ex/sec':10s}")
         local_print("\n")

     def print_measurements(prefix, nelem, measurements):
         measurements = sorted(measurements)
-        local_print("%8s:" % prefix)
+        local_print(f"{prefix:8s}:")
         for p in [50, 75, 90, 95]:
             v = np.percentile(measurements, p)
-            local_print(" p%02d: %1.3fs %6d/s" % (p, v, nelem / v))
+            local_print(f" p{p:02d}: {v:1.3f}s {nelem / v:6d}/s")
         local_print("\n")

     # Every process runs once by themselves to warm up (CUDA init, etc).

@@ -133,7 +133,7 @@ def sweep(benchmark):

     # Multi-machine benchmarks
     for i in range(1, (dist.get_world_size() // 8) + 1):
-        append_benchmark(" %dM/8G" % i, range(i * 8))
+        append_benchmark(f" {i:d}M/8G", range(i * 8))

     # Run benchmarks in order of increasing number of GPUs
     print_header()

@@ -239,7 +239,7 @@ def main():
     print()

     torch.cuda.set_device(dist.get_rank() % 8)
-    device = torch.device("cuda:%d" % (dist.get_rank() % 8))
+    device = torch.device(f"cuda:{dist.get_rank() % 8:d}")

     benchmarks = []
     if args.model:
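A note on the `%d` → `:d` conversions in hunks like the one above: the two specs agree for integers, but `:d` rejects floats that `%d` would silently truncate, so a converted call site that feeds a float (for example `{nelem / v:6d}`) needs an explicit `int(...)` to format without error. A quick standalone illustration, not taken from the diff:

    "%6d" % 3.7          # '     3' -- %d truncates the float
    f"{int(3.7):6d}"     # '     3' -- f-strings need the explicit conversion
    # f"{3.7:6d}"        # ValueError: Unknown format code 'd' for object of type 'float'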
@@ -314,7 +314,7 @@ class DeepSpeech(nn.Module):
                 rnn_type=rnn_type,
                 bidirectional=bidirectional,
             )
-            rnns.append(("%d" % (x + 1), rnn))
+            rnns.append((f"{x + 1:d}", rnn))
         self.rnns = nn.Sequential(OrderedDict(rnns))
         self.lookahead = (
             nn.Sequential(
@@ -197,7 +197,7 @@ Works only with Python3.\n A few examples:
             cpu_count += 1
             if cpu_count > 1:
                 raise ValueError(
-                    "more than one CPU device is not allowed: %d" % (cpu_count)
+                    f"more than one CPU device is not allowed: {cpu_count:d}"
                 )
             if device == "cpu":
                 continue
@@ -14,34 +14,34 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)

         if InType == "float":
             code.append(
-                " vop%d = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (%d)), vop%d);"  # noqa
-                % (regid, regid, regid)
+                f" vop{regid:d} = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + ({regid:d})), vop{regid:d});"  # noqa
             )
         elif InType == "at::Half":
             code.append(
-                " vop%d = _mm256_fmadd_ps(\n"
+                f" vop{regid:d} = _mm256_fmadd_ps(\n"
                 " vwgt,\n"
                 " _mm256_cvtph_ps(\n"
-                " _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))),\n"  # noqa
-                " vop%d);" % (regid, regid, regid)
+                f" _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + ({regid:d})))),\n"  # noqa
+                f" vop{regid:d});"
             )
         elif InType == "at::BFloat16":
             code.append(
-                " vop%d = _mm256_fmadd_ps(\n"
+                f" vop{regid:d} = _mm256_fmadd_ps(\n"
                 " vwgt,\n"
                 " _mm256_castsi256_ps(_mm256_slli_epi32(\n"
                 " _mm256_cvtepu16_epi32(_mm_loadu_si128(\n"
-                " reinterpret_cast<const __m128i*>(ip + (%d)))),\n"
+                f" reinterpret_cast<const __m128i*>(ip + ({regid:d})))),\n"
                 " 16)),\n"  # noqa
-                " vop%d);" % (regid, regid, regid)
+                f" vop{regid:d});"
             )
         elif InType == "uint8_t":
             code.append(
-                " vop%d = _mm256_fmadd_ps(\n"
+                f" vop{regid:d} = _mm256_fmadd_ps(\n"
                 " vwgt,\n"
                 " _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(\n"
-                " _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (%d))))),\n"  # noqa
-                " _mm256_add_ps(vop%d, vbio));" % (regid, regid, regid)
+                f" _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + ({regid:d}))))),\n"  # noqa
+                f" _mm256_add_ps(vop{regid:d}, vbio));"
             )
         else:
             assert False
@@ -49,12 +49,12 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
         if prefetch:
             code.append(
                 " _mm_prefetch(\n"
-                " reinterpret_cast<const char*>(&ip_next_T0[%d]), _MM_HINT_T0);"
-                % (regid)
+                f" reinterpret_cast<const char*>(&ip_next_T0[{regid:d}]), _MM_HINT_T0);"
             )
         else:
             code.append(
-                " // skip unnecessary prefetch of (&ip_next_T0[%d])" % (regid)
+                f" // skip unnecessary prefetch of (&ip_next_T0[{regid:d}])"
             )

     return code
@@ -142,15 +142,13 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
     code.append(" }")
     code.append(" __m256 vwgt = _mm256_set1_ps(wgt);")

-    code.append(" const {}* ip = &input[idx * fused_block_size];".format(InType))
+    code.append(f" const {InType}* ip = &input[idx * fused_block_size];")
     code.append(
-        " const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
+        f" const {IndexType} next_T0 = (dataInd < index_size - prefdist_T0)\n"
         " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
         " ? (dataInd + prefdist_T0)\n"
         " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
-        " : dataInd;".format(
-            IndexType
-        )
+        " : dataInd;"
     )
     code.append(" const " + IndexType + " idx_pref_T0 = indices[next_T0];")
     code.append(

@@ -160,8 +158,8 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused, use_offsets)
     )

     code.append(
-        " const {}* ip_next_T0 = "
-        "&input[idx_pref_T0 * fused_block_size];".format(InType)
+        f" const {InType}* ip_next_T0 = "
+        "&input[idx_pref_T0 * fused_block_size];"
     )

     for i in range(0, uf):
@@ -346,15 +344,13 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
     code.append(" }")
     code.append(" __m256 vwgt = _mm256_set1_ps(wgt);")

-    code.append(" const {}* ip = &input[idx * fused_block_size];".format(InType))
+    code.append(f" const {InType}* ip = &input[idx * fused_block_size];")
     code.append(
-        " const {} next_T0 = (dataInd < index_size - prefdist_T0)\n"
+        f" const {IndexType} next_T0 = (dataInd < index_size - prefdist_T0)\n"
         " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
         " ? (dataInd + prefdist_T0)\n"
         " // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)\n"
-        " : dataInd;".format(
-            IndexType
-        )
+        " : dataInd;"
     )
     code.append(" const " + IndexType + " idx_pref_T0 = indices[next_T0];")
     code.append(

@@ -363,8 +359,8 @@ def generic(IndexType, InType, OutType, use_weights, isa, fused, use_offsets):
         + " }"
     )
     code.append(
-        " const {}* ip_next_T0 = "
-        "&input[idx_pref_T0 * fused_block_size];".format(InType)
+        f" const {InType}* ip_next_T0 = "
+        "&input[idx_pref_T0 * fused_block_size];"
     )

     # compute and store main loop
@@ -459,7 +455,7 @@ code = []
 code.append("//// --------------------------")
 code.append("//// ATTENTION:")
 code.append("//// THIS CODE IS AUTOGENERATED")
-code.append("//// BY {}".format(sys.argv[0]))
+code.append(f"//// BY {sys.argv[0]}")
 code.append("//// DO NOT MODIFY!!!")
 code.append("//// --------------------------\n")

@@ -474,13 +470,9 @@ for o in options:
     prefix = "Fused8BitRowwise" if opts.fused else ""
     code.append("template <bool IS_WEIGHT_POSITIONAL>")
     if opts.use_offsets:
-        fn_base = "{}EmbeddingLookupIdx_{}_{}_{}".format(
-            prefix, IndexTypeName, InTypeName, OutTypeName
-        )
+        fn_base = f"{prefix}EmbeddingLookupIdx_{IndexTypeName}_{InTypeName}_{OutTypeName}"
     else:
-        fn_base = "{}EmbeddingLookup_{}_{}_{}".format(
-            prefix, IndexTypeName, InTypeName, OutTypeName
-        )
+        fn_base = f"{prefix}EmbeddingLookup_{IndexTypeName}_{InTypeName}_{OutTypeName}"
     suffix = "__avx2_fma"
     fn = "static bool " + fn_base + suffix
     code.append(fn + "(")

@@ -509,7 +501,7 @@ for o in options:
     # an entire row, including scale and bias.
     offset = (8 // sizeof[InType]) if opts.fused else 0
     code.append(
-        " const {} fused_block_size = block_size + {};".format(IndexType, offset)
+        f" const {IndexType} fused_block_size = block_size + {offset};"
     )
     if opts.use_offsets:
         code.append(" int64_t dataInd = 0;")
@@ -97,7 +97,7 @@ for it in range(20000):
     opt.step()

     if it % 100 == 0:
-        print("Iteration %d -- Outer Loss: %.4f" % (it, loss2))
+        print(f"Iteration {it:d} -- Outer Loss: {loss2:.4f}")
     losses.append(loss2.detach())

 t_A = torch.tensor(0.0).uniform_(0.1, 0.5)

@@ -101,7 +101,7 @@ for it in range(20000):
     opt.step()

     if it % 100 == 0:
-        print("Iteration %d -- Outer Loss: %.4f" % (it, loss2))
+        print(f"Iteration {it:d} -- Outer Loss: {loss2:.4f}")
     losses.append(loss2.detach())

 t_A = torch.tensor(0.0).uniform_(0.1, 0.5)

@@ -97,7 +97,7 @@ for it in range(20000):
     opt.step()

     if it % 100 == 0:
-        print("Iteration %d -- Outer Loss: %.4f" % (it, loss2))
+        print(f"Iteration {it:d} -- Outer Loss: {loss2:.4f}")
     losses.append(loss2.detach())

 t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
@@ -106,7 +106,7 @@ class AbstractTimeoutTest:
             else:
                 yield f"file://{f.name}"
                 f.close()
-        yield "tcp://127.0.0.1:%d" % common.find_free_port()
+        yield f"tcp://127.0.0.1:{common.find_free_port():d}"

     def _test_default_store_timeout(self, backend):
         for init_method in self._init_methods():

@@ -339,7 +339,7 @@ class CommonDistributedDataParallelTest:
         gradient_as_bucket_view=False,
     ):
         model = Net()
-        device = devices[0] if devices else torch.device("cuda:%d" % self.rank)
+        device = devices[0] if devices else torch.device(f"cuda:{self.rank:d}")
         ddp_model = DistributedDataParallel(
             copy.deepcopy(model).to(device),
             device_ids=device_ids,
@@ -377,7 +377,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
             self.assertEqual(
                 torch.tensor([(i * self.world_size) + (i % self.world_size)]),
                 inputs[i],
-                msg=("Mismatch in iteration %d" % i),
+                msg=(f"Mismatch in iteration {i:d}"),
             )

     @requires_gloo()

@@ -482,7 +482,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
                     ]
                 ),
                 future_handle.value()[0],
-                msg=("Mismatch in iteration %d" % i),
+                msg=(f"Mismatch in iteration {i:d}"),
             )

     @requires_gloo()

@@ -897,7 +897,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
             self.assertEqual(
                 torch.tensor([iter + root]),
                 result[0],
-                msg=("Mismatch in iteration %d for rank %d" % (iter, root)),
+                msg=(f"Mismatch in iteration {iter:d} for rank {root:d}"),
             )

     @requires_gloo()

@@ -1088,7 +1088,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
             self.assertEqual(
                 expected_outputs[iter],
                 [result],
-                msg=("Mismatch in iteration %d for root %d" % (iter, root)),
+                msg=(f"Mismatch in iteration {iter:d} for root {root:d}"),
             )

     @requires_gloo()

@@ -1223,7 +1223,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
             self.assertEqual(
                 expected_outputs[i],
                 [result],
-                msg=("Mismatch in iteration %d" % i),
+                msg=(f"Mismatch in iteration {i:d}"),
             )

     @requires_gloo()

@@ -1409,7 +1409,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
                     ]
                 ),
                 result[0],
-                msg=("Mismatch in iteration %d with root rank %d" % (iter, root)),
+                msg=(f"Mismatch in iteration {iter:d} with root rank {root:d}"),
             )

     @requires_gloo()
@@ -1916,7 +1916,7 @@ class DistributedDataParallelTest(
         torch.save(ddp_withload.state_dict(), checkpoint_path)

         dist.barrier()
-        map_location = {"cuda:%d" % 0: "cuda:%d" % self.rank}
+        map_location = {"cuda:0": f"cuda:{self.rank:d}"}
         ddp_state_dict = torch.load(checkpoint_path, map_location=map_location)

         for model in [ddp_withload, model_withload]:

@@ -2360,7 +2360,7 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
             backend="gloo", store=store, rank=self.rank, world_size=self.world_size
         )
         process_group = c10d.distributed_c10d._get_default_group()
-        device = torch.device("cuda:%d" % self.rank)
+        device = torch.device(f"cuda:{self.rank:d}")
         backend = process_group._get_backend(device)
         backend.create_device(interface=LOOPBACK)
         ranks = list(range(self.world_size))
@@ -506,7 +506,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
         # should not check on receive buffer
         os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
         store = c10d.FileStore(self.file_name, self.world_size)
-        device = torch.device("cuda:%d" % self.rank)
+        device = torch.device(f"cuda:{self.rank:d}")
         c10d.init_process_group(
             backend="nccl", store=store, rank=self.rank, world_size=self.world_size
         )

@@ -529,7 +529,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
     @skip_if_lt_x_gpu(2)
     def test_nan_check(self):
         # Not expecting an error, NaN check should not make legit code fail
-        device = torch.device("cuda:%d" % self.rank)
+        device = torch.device(f"cuda:{self.rank:d}")
         if not sm_is_or_higher_than(device, 8, 0):
             self.skipTest("bf16 requires sm >= 8.0")

@@ -557,7 +557,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):

         pynvml.nvmlInit()

-        device = torch.device("cuda:%d" % self.rank)
+        device = torch.device(f"cuda:{self.rank:d}")
         x = torch.empty((1,), device=device)
         work = c10d.all_reduce(x, async_op=True)

@@ -585,7 +585,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
         A helper for `test_extra_cuda_context`, if pynvml is NOT avaiable.
         If extra context is created, it would manifest into device 0's memory usage.
         """
-        device = torch.device("cuda:%d" % self.rank)
+        device = torch.device(f"cuda:{self.rank:d}")
         x = torch.empty((1,), device=device)
         # Rank 0 takes a snapshot before collective -- this snapshot should have
         # included rank 0's own context.

@@ -623,7 +623,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
     def test_extra_cuda_context(self):
         # Check if non-0 ranks would create extra CUDA context on device 0
         store = c10d.FileStore(self.file_name, self.world_size)
-        device = torch.device("cuda:%d" % self.rank)
+        device = torch.device(f"cuda:{self.rank:d}")
         c10d.init_process_group(
             backend="nccl",
             store=store,
@@ -3148,7 +3148,7 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
             backend="nccl", store=store, rank=self.rank, world_size=self.world_size
         )
         process_group = c10d.distributed_c10d._get_default_group()
-        device = torch.device("cuda:%d" % self.rank)
+        device = torch.device(f"cuda:{self.rank:d}")
         ranks = [0, 1]
         for root_rank in ranks:
             self._test_broadcast_coalesced(process_group, device, root_rank)

@@ -3161,7 +3161,7 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
             backend="nccl", store=store, rank=self.rank, world_size=self.world_size
         )
         process_group = c10d.distributed_c10d._get_default_group()
-        device = torch.device("cuda:%d" % self.rank)
+        device = torch.device(f"cuda:{self.rank:d}")
         tensors = [
             torch.full((60 + i,), self.rank + 1 + i, device=device, dtype=torch.float)
             for i in range(5)

@@ -3183,7 +3183,7 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
             backend="nccl", store=store, rank=self.rank, world_size=self.world_size
         )
         process_group = c10d.distributed_c10d._get_default_group()
-        device = torch.device("cuda:%d" % self.rank)
+        device = torch.device(f"cuda:{self.rank:d}")
         tensors = [
             torch.full(
                 (60 + i,), self.rank + 1 + i, device=device, dtype=torch.float

@@ -3204,7 +3204,7 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
             backend="nccl", store=store, rank=self.rank, world_size=self.world_size
         )
         process_group = c10d.distributed_c10d._get_default_group()
-        device = torch.device("cuda:%d" % self.rank)
+        device = torch.device(f"cuda:{self.rank:d}")
         tensors = [
             torch.full((60 + i,), self.rank + 1 + i, device=device, dtype=torch.float)
             for i in range(5)
@@ -3748,7 +3748,7 @@ class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase
             return

         subgroup = self._init_two_pg2_subgroups(world_size)
-        device = torch.device("cuda:%d" % self.rank)
+        device = torch.device(f"cuda:{self.rank:d}")
         input = torch.ones((10,), device=device) * self.rank
         if self.rank == 0 or self.rank == 2:
             gather_list = [torch.empty_like(input) for _ in range(subgroup.size())]

@@ -3839,7 +3839,7 @@ class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase
         if self.rank >= world_size:
             return
         subgroup = self._init_two_pg2_subgroups(world_size)
-        device = torch.device("cuda:%d" % self.rank)
+        device = torch.device(f"cuda:{self.rank:d}")
         x = torch.ones((10,), device=device) * self.rank
         if self.rank == 0 or self.rank == 2:
             expected = x + torch.ones((10,), device=device) * (self.rank + 1)

@@ -3863,7 +3863,7 @@ class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase
         if self.rank >= world_size:
             return
         subgroup = self._init_two_pg2_subgroups(world_size)
-        device = torch.device("cuda:%d" % self.rank)
+        device = torch.device(f"cuda:{self.rank:d}")
         if self.rank == 0 or self.rank == 2:
             x = torch.empty((10,), device=device)
             if async_op:

@@ -3899,7 +3899,7 @@ class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase
         if self.rank >= world_size:
             return
         subgroup = self._init_two_pg2_subgroups(world_size)
-        device = torch.device("cuda:%d" % self.rank)
+        device = torch.device(f"cuda:{self.rank:d}")
         ops = []
         if self.rank == 0 or self.rank == 2:
             x = torch.empty((10,), device=device)

@@ -3933,7 +3933,7 @@ class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase
         if self.rank >= world_size:
             return
         subgroup = self._init_two_pg2_subgroups(world_size)
-        device = torch.device("cuda:%d" % self.rank)
+        device = torch.device(f"cuda:{self.rank:d}")
         if self.rank == 0 or self.rank == 2:
             x = torch.empty((10,), device=device)
             if group_rank:

@@ -3967,7 +3967,7 @@ class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase
             torch.cuda.set_device(self.rank)
             device = None
         else:
-            device = torch.device("cuda:%d" % self.rank)
+            device = torch.device(f"cuda:{self.rank:d}")
         if self.rank == 0 or self.rank == 2:
             x = [{}]
             if group_rank:

@@ -4005,7 +4005,7 @@ class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase
             torch.cuda.set_device(self.rank)
             device = None
         else:
-            device = torch.device("cuda:%d" % self.rank)
+            device = torch.device(f"cuda:{self.rank:d}")
         if self.rank == 0 or self.rank == 2:
             x = [{}]
             if group_rank:

@@ -4037,7 +4037,7 @@ class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase
         if self.rank >= world_size:
             return
         subgroup = self._init_two_pg2_subgroups(world_size)
-        device = torch.device("cuda:%d" % self.rank)
+        device = torch.device(f"cuda:{self.rank:d}")
         x = torch.empty((10,), device=device)
         expected = torch.ones((10,), device=device) * self.rank
         if self.rank == 0 or self.rank == 2:
@@ -751,7 +751,7 @@ class DistributedDataParallelTest(
         torch.save(ddp_withload.state_dict(), checkpoint_path)

         dist.barrier()
-        map_location = {"cuda:%d" % 0: "cuda:%d" % self.rank}
+        map_location = {"cuda:0": f"cuda:{self.rank:d}"}
         ddp_state_dict = torch.load(checkpoint_path, map_location=map_location)

         for model in [ddp_withload, model_withload]:
@@ -704,7 +704,7 @@ class RendezvousTCPTest(TestCase):
     def create_tcp_url(self):
         addr = DEFAULT_HOSTNAME
         port = common.find_free_port()
-        url = "tcp://%s:%d?world_size=%d" % (addr, port, 1)
+        url = f"tcp://{addr}:{port:d}?world_size=1"
         return url

     def test_common_errors(self):
@@ -538,7 +538,7 @@ class _DenseBlock(torch.nn.ModuleDict):
     ) -> None:
         super().__init__()
         for i in range(num_layers):
-            self.add_module("denselayer%d" % (i + 1), _Block())
+            self.add_module(f"denselayer{i + 1:d}", _Block())

     def forward(self, init_features):
         features = [init_features]

@@ -825,7 +825,7 @@ class EnumValues(torch.nn.ModuleDict):
     ) -> None:
         super().__init__()
         for i in range(num_layers):
-            self.add_module("denselayer%d" % (i + 1), _Block())
+            self.add_module(f"denselayer{i + 1:d}", _Block())

     def forward(self, init_features):
         features = [init_features]

@@ -842,7 +842,7 @@ class AccessByKeys(torch.nn.ModuleDict):
     ) -> None:
         super().__init__()
         for i in range(num_layers):
-            self.add_module("denselayer%d" % (i + 1), _Block())
+            self.add_module(f"denselayer{i + 1:d}", _Block())

     def forward(self, init_features):
         features = [init_features]

@@ -1036,7 +1036,7 @@ class ModuleGuardNameIsValid(torch.nn.ModuleDict):
     def __init__(self) -> None:
         super().__init__()
         for i in range(2):
-            self.add_module("l@yer-%d" % (i + 1), BasicModule())
+            self.add_module(f"l@yer-{i + 1:d}", BasicModule())

     def forward(self, x):
         for layer in self.values():
@@ -765,7 +765,7 @@ class TestMkldnn(TestCase):
                 y_bf16 = max_pool(x_bf16.to_mkldnn()).to_dense(torch.float32)
                 self.assertEqual(y, y_bf16, atol=0.1, rtol=1e-3)
             else:
-                msg = "mkldnn_max_pool%dd: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" % dim
+                msg = f"mkldnn_max_pool{dim:d}d: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
                 self.assertRaisesRegex(RuntimeError,
                                        msg,
                                        lambda: max_pool(x_bf16.to_mkldnn()))

@@ -883,7 +883,7 @@ class TestMkldnn(TestCase):
                 y_bf16 = avg_pool(x_bf16.to_mkldnn()).to_dense(torch.float)
                 self.assertEqual(y, y_bf16, atol=1e-1, rtol=1e-3)
             else:
-                msg = "mkldnn_avg_pool%dd: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq" % dim
+                msg = f"mkldnn_avg_pool{dim:d}d: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"
                 self.assertRaisesRegex(RuntimeError,
                                        msg,
                                        lambda: avg_pool(x_bf16.to_mkldnn()))
@@ -30,13 +30,13 @@ def _test_success_single_arg_func(i, arg):

 def _test_exception_single_func(i, arg):
     if i == arg:
-        raise ValueError("legitimate exception from process %d" % i)
+        raise ValueError(f"legitimate exception from process {i:d}")
     time.sleep(1.0)


 def _test_exception_all_func(i):
     time.sleep(random.random() / 10)
-    raise ValueError("legitimate exception from process %d" % i)
+    raise ValueError(f"legitimate exception from process {i:d}")


 def _test_terminate_signal_func(i):

@@ -120,7 +120,7 @@ class _TestMultiProcessing:
         for i in range(nprocs):
             with self.assertRaisesRegex(
                 Exception,
-                "\nValueError: legitimate exception from process %d$" % i,
+                f"\nValueError: legitimate exception from process {i:d}$",
             ):
                 mp.start_processes(_test_exception_single_func, args=(i,), nprocs=nprocs, start_method=self.start_method)

@@ -153,13 +153,13 @@ class _TestMultiProcessing:
         pid1 = ctx.processes[1].pid
         with self.assertRaisesRegex(
             Exception,
-            "process 0 terminated with exit code %d" % exitcode,
+            f"process 0 terminated with exit code {exitcode:d}",
         ), self.assertLogs(level='WARNING') as logs:
             while not ctx.join(grace_period=grace_period):
                 pass
         if grace_period is None:
             # pid1 is killed by signal.
-            expected_log = "Terminating process %d via signal" % pid1
+            expected_log = f"Terminating process {pid1:d} via signal"
             self.assertIn(expected_log, logs.records[0].getMessage())
         else:
             # pid1 exits on its own.
@@ -320,29 +320,29 @@ class TestHash(TestCase):
         ]:
             for i in range(1, s):
                 assert_equal(
-                    hash(st(-(2**i))), hash(-(2**i)), err_msg="%r: -2**%d" % (st, i)
+                    hash(st(-(2**i))), hash(-(2**i)), err_msg=f"{st!r}: -2**{i:d}"
                 )
                 assert_equal(
                     hash(st(2 ** (i - 1))),
                     hash(2 ** (i - 1)),
-                    err_msg="%r: 2**%d" % (st, i - 1),
+                    err_msg=f"{st!r}: 2**{i - 1:d}",
                 )
                 assert_equal(
                     hash(st(2**i - 1)),
                     hash(2**i - 1),
-                    err_msg="%r: 2**%d - 1" % (st, i),
+                    err_msg=f"{st!r}: 2**{i:d} - 1",
                 )

                 i = max(i - 1, 1)
                 assert_equal(
                     hash(ut(2 ** (i - 1))),
                     hash(2 ** (i - 1)),
-                    err_msg="%r: 2**%d" % (ut, i - 1),
+                    err_msg=f"{ut!r}: 2**{i - 1:d}",
                 )
                 assert_equal(
                     hash(ut(2**i - 1)),
                     hash(2**i - 1),
-                    err_msg="%r: 2**%d - 1" % (ut, i),
+                    err_msg=f"{ut!r}: 2**{i:d} - 1",
                 )
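Hunks like the one above also convert `%r` to the f-string `!r` conversion, which applies `repr()` to the value exactly as `%r` did. A tiny standalone illustration, not taken from the diff:

    x = "hi"
    "%r: %d" % (x, 2)   # "'hi': 2" -- %r applies repr()
    f"{x!r}: {2:d}"     # "'hi': 2" -- !r is the f-string equivalent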
@@ -2340,11 +2340,11 @@ class TestMethods(TestCase):
             # array_less does not seem to work right
             at(
                 (p[:, :i].T <= p[:, i]).all(),
-                msg="%d: %r <= %r" % (i, p[:, i], p[:, :i].T),
+                msg=f"{i:d}: {p[:, i]!r} <= {p[:, :i].T!r}",
             )
             at(
                 (p[:, i + 1 :].T > p[:, i]).all(),
-                msg="%d: %r < %r" % (i, p[:, i], p[:, i + 1 :].T),
+                msg=f"{i:d}: {p[:, i]!r} < {p[:, i + 1 :].T!r}",
             )
             aae(
                 p,

@@ -2359,11 +2359,11 @@ class TestMethods(TestCase):
             # array_less does not seem to work right
             at(
                 (p[:i, :] <= p[i, :]).all(),
-                msg="%d: %r <= %r" % (i, p[i, :], p[:i, :]),
+                msg=f"{i:d}: {p[i, :]!r} <= {p[:i, :]!r}",
             )
             at(
                 (p[i + 1 :, :] > p[i, :]).all(),
-                msg="%d: %r < %r" % (i, p[i, :], p[:, i + 1 :]),
+                msg=f"{i:d}: {p[i, :]!r} < {p[:, i + 1 :]!r}",
             )
             aae(
                 p,

@@ -2387,10 +2387,10 @@ class TestMethods(TestCase):
     def assert_partitioned(self, d, kth):
         prev = 0
         for k in np.sort(kth):
-            assert_array_less(d[prev:k], d[k], err_msg="kth %d" % k)
+            assert_array_less(d[prev:k], d[k], err_msg=f"kth {k:d}")
             assert_(
                 (d[k:] >= d[k]).all(),
-                msg="kth %d, %r not greater equal %d" % (k, d[k:], d[k]),
+                msg=f"kth {k:d}, {d[k:]!r} not greater equal {d[k]:d}",
             )
             prev = k + 1
@@ -3971,7 +3971,7 @@ class TestIO(TestCase):
             f.write(b"\0")

         for mode in ["rb", "r+b"]:
-            err_msg = "%d %s" % (size, mode)
+            err_msg = f"{size:d} {mode}"

             with open(tmp_filename, mode) as f:
                 f.read(2)

@@ -3988,7 +3988,7 @@ class TestIO(TestCase):
         ]

         for size in sizes:
-            err_msg = "%d" % (size,)
+            err_msg = f"{size:d}"

             with open(tmp_filename, "wb") as f:
                 f.seek(size - 1)

@@ -5905,7 +5905,7 @@ class TestPEP3118Dtype(TestCase):
             if j == 0:
                 s = "bi"
             else:
-                s = "b%dxi" % j
+                s = f"b{j:d}xi"
             self._check(
                 "@" + s, {"f0": ("i1", 0), "f1": ("i", align * (1 + j // align))}
             )
@@ -111,8 +111,8 @@ class TestTypes(TestCase):
                 assert_equal(
                     c_scalar.dtype,
                     c_array.dtype,
-                    "error with types (%d/'%s' + %d/'%s')"
-                    % (k, np.dtype(atype).name, l, np.dtype(btype).name),
+                    "error with types "
+                    f"({k:d}/'{np.dtype(atype).name}' + {l:d}/'{np.dtype(btype).name}')",
                 )

     def test_type_create(self):
third_party/generate-cpuinfo-wrappers.py (vendored, 8 changed lines)

@@ -86,9 +86,9 @@ if __name__ == "__main__":
             print(file=wrapper)

             if not condition:
-                print("#include <%s>" % filename, file=wrapper)
+                print(f"#include <{filename}>", file=wrapper)
             else:
                 # Include source file only if condition is satisfied
-                print("#if %s" % condition, file=wrapper)
-                print("#include <%s>" % filename, file=wrapper)
-                print("#endif /* %s */" % condition, file=wrapper)
+                print(f"#if {condition}", file=wrapper)
+                print(f"#include <{filename}>", file=wrapper)
+                print(f"#endif /* {condition} */", file=wrapper)
third_party/generate-xnnpack-wrappers.py (vendored, 14 changed lines)

@@ -167,7 +167,7 @@ def gen_wrappers(xnnpack_path):
         if not os.path.isdir(os.path.dirname(filepath)):
             os.makedirs(os.path.dirname(filepath))
         with open(filepath, "w") as wrapper:
-            print("/* {} */".format(BANNER), file=wrapper)
+            print(f"/* {BANNER} */", file=wrapper)
             print(file=wrapper)

             # Architecture- or platform-dependent preprocessor flags can be

@@ -175,12 +175,12 @@ def gen_wrappers(xnnpack_path):
             # because they are ignored by arc focus & buck project.

             if condition is None:
-                print("#include <%s>" % filename, file=wrapper)
+                print(f"#include <{filename}>", file=wrapper)
             else:
                 # Include source file only if condition is satisfied
-                print("#if %s" % condition, file=wrapper)
-                print("#include <%s>" % filename, file=wrapper)
-                print("#endif /* %s */" % condition, file=wrapper)
+                print(f"#if {condition}", file=wrapper)
+                print(f"#include <{filename}>", file=wrapper)
+                print(f"#endif /* {condition} */", file=wrapper)

     # update xnnpack_wrapper_defs.bzl file under the same folder
     with open(os.path.join(os.path.dirname(__file__), "xnnpack_wrapper_defs.bzl"), 'w') as wrapper_defs:

@@ -190,7 +190,7 @@ def gen_wrappers(xnnpack_path):
     for name in WRAPPER_SRC_NAMES:
         print('\n' + name + ' = [', file=wrapper_defs)
         for file_name in sources[name]:
-            print(' "xnnpack_wrappers/{}",'.format(file_name), file=wrapper_defs)
+            print(f' "xnnpack_wrappers/{file_name}",', file=wrapper_defs)
         print(']', file=wrapper_defs)

     # update xnnpack_src_defs.bzl file under the same folder

@@ -201,7 +201,7 @@ def gen_wrappers(xnnpack_path):
     for name in SRC_NAMES:
         print('\n' + name + ' = [', file=src_defs)
         for file_name in sources[name]:
-            print(' "XNNPACK/src/{}",'.format(file_name), file=src_defs)
+            print(f' "XNNPACK/src/{file_name}",', file=src_defs)
         print(']', file=src_defs)
@@ -24,7 +24,7 @@ def convert_time(seconds: float) -> str:
     minutes = seconds // 60
     seconds %= 60

-    return "%d:%02d:%02d" % (hour, minutes, seconds)
+    return f"{hour:d}:{minutes:02d}:{seconds:02d}"


 def print_time(message: str, start_time: float, summary_time: bool = False) -> None:
@@ -148,7 +148,7 @@ if try_import_cutlass():
                 "layout_d": LayoutTag[instance_layout_D],  # type: ignore[name-defined]
                 "element_accumulator": DataTypeTag[operation.accumulator_type()],  # type: ignore[name-defined]
                 "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],  # type: ignore[name-defined] # noqa: B950
-                "arch": "cutlass::arch::Sm%d" % operation.arch,
+                "arch": f"cutlass::arch::Sm{operation.arch:d}",
                 "tile_shape_m": str(operation.tile_description.tile_shape[0]),
                 "tile_shape_n": str(operation.tile_description.tile_shape[1]),
                 "tile_shape_k": str(operation.tile_description.tile_shape[2]),
@@ -403,7 +403,7 @@ def assert_almost_equal(actual, desired, decimal=7, err_msg="", verbose=True):
         usecomplex = False

     def _build_err_msg():
-        header = "Arrays are not almost equal to %d decimals" % decimal
+        header = f"Arrays are not almost equal to {decimal:d} decimals"
         return build_err_msg([actual, desired], err_msg, verbose=verbose, header=header)

     if usecomplex:

@@ -526,7 +526,7 @@ def assert_approx_equal(actual, desired, significant=7, err_msg="", verbose=True
             msg = build_err_msg(
                 [actual, desired],
                 err_msg,
-                header="Items are not equal to %d significant digits:" % significant,
+                header=f"Items are not equal to {significant:d} significant digits:",
                 verbose=verbose,
             )
             try:

@@ -944,7 +944,7 @@ def assert_array_almost_equal(x, y, decimal=6, err_msg="", verbose=True):
         y,
         err_msg=err_msg,
         verbose=verbose,
-        header=("Arrays are not almost equal to %d decimals" % decimal),
+        header=f"Arrays are not almost equal to {decimal:d} decimals",
         precision=decimal,
     )

@@ -1359,10 +1359,10 @@ def assert_array_almost_equal_nulp(x, y, nulp=1):
     ref = nulp * np.spacing(np.where(ax > ay, ax, ay))
     if not np.all(np.abs(x - y) <= ref):
         if np.iscomplexobj(x) or np.iscomplexobj(y):
-            msg = "X and Y are not equal to %d ULP" % nulp
+            msg = f"X and Y are not equal to {nulp:d} ULP"
         else:
             max_nulp = np.max(nulp_diff(x, y))
-            msg = "X and Y are not equal to %d ULP (max is %g)" % (nulp, max_nulp)
+            msg = f"X and Y are not equal to {nulp:d} ULP (max is {max_nulp:g})"
         raise AssertionError(msg)
@@ -1394,9 +1394,9 @@ def _get_notallclose_msg(
     )
     mode = "computed with forward mode " if is_forward_ad else ""
     return (
-        prefix + "Jacobian %smismatch for output %d with respect to input %d,\n"
-        "numerical:%s\nanalytical:%s\n"
-        % (mode, output_idx, input_idx, numerical, analytical)
+        prefix
+        + f"Jacobian {mode}mismatch for output {output_idx:d} with respect to input {input_idx:d},\n"
+        f"numerical:{numerical}\nanalytical:{analytical}\n"
     )
@@ -86,18 +86,18 @@ def _retrieve_embedding_parameters(emb_rref):

 def _print_header():
     _print_cont("\n")
-    _print_cont("%10s" % "")
+    _print_cont(" " * 10)
     for _ in [50, 75, 90, 95]:
-        _print_cont("%14s%10s" % ("sec/epoch", "epoch/sec"))
+        _print_cont(f"{'sec/epoch':14s}{'epoch/sec':10s}")
     _print_cont("\n")


 def _print_benchmark(prefix, nelem, measurements):
     measurements = sorted(measurements)
-    _print_cont("%8s:" % prefix)
+    _print_cont(f"{prefix:8s}:")
     for p in [50, 75, 90, 95]:
         v = np.percentile(measurements, p)
-        _print_cont(" p%02d: %1.3fs %6d/s" % (p, v, nelem / v))
+        _print_cont(f" p{p:02d}: {v:1.3f}s {nelem / v:6d}/s")
     _print_cont("\n")
@@ -347,13 +347,13 @@ class _ReaderView(io.IOBase):
         self.base_stream = base_stream
         self.seek(0)

-    def seek(self, __offset: int, __whence: int = os.SEEK_SET) -> int:
-        if __whence == os.SEEK_SET:
-            __offset = self.offset + __offset
-        elif __whence == os.SEEK_END:
-            __whence = os.SEEK_SET
-            __offset = (self.offset + self.len) - __offset
-        return self.base_stream.seek(__offset, __whence)
+    def seek(self, offset: int, whence: int = os.SEEK_SET, /) -> int:
+        if whence == os.SEEK_SET:
+            offset = self.offset + offset
+        elif whence == os.SEEK_END:
+            whence = os.SEEK_SET
+            offset = (self.offset + self.len) - offset
+        return self.base_stream.seek(offset, whence)

     def tell(self) -> int:
         return self.base_stream.tell() - self.offset
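The `_ReaderView.seek` hunk above is representative of change 3. A minimal, self-contained sketch of the difference (class and parameter names here are illustrative, not taken from PyTorch):

    import os

    class Demo:
        # Before: the double-underscore prefix marks parameters as
        # positional-only by convention; type checkers honor it, but the
        # interpreter does not enforce it.
        def seek_old(self, __offset: int, __whence: int = os.SEEK_SET) -> int:
            return __offset + __whence

        # After: PEP 570 -- every parameter before the bare "/" is
        # positional-only, enforced at call time.
        def seek_new(self, offset: int, whence: int = os.SEEK_SET, /) -> int:
            return offset + whence

    d = Demo()
    d.seek_new(10)            # OK
    # d.seek_new(offset=10)   # TypeError: positional-only argument passed as keyword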
@@ -98,12 +98,12 @@ class _FSDPDeviceHandle:
             return cast(_FSDPDeviceHandle, torch.mtia)
         return cls(device)

-    def __getattr__(self, __name: str) -> Any:
+    def __getattr__(self, name: str, /) -> Any:
         try:
-            return getattr(self.__backend, __name)
+            return getattr(self.__backend, name)
         except AttributeError as exc:
             raise AttributeError(
-                f"Custom backend '{self.__device.type}' not implement 'torch.{self.__device.type}.{__name}'"
+                f"Custom backend '{self.__device.type}' not implement 'torch.{self.__device.type}.{name}'"
             ) from exc

@@ -111,7 +111,7 @@ class _UninitializedDeviceHandle(_FSDPDeviceHandle):
     def __init__(self) -> None:
         pass

-    def __getattribute__(self, __name: str) -> Any:
+    def __getattribute__(self, name: str, /) -> Any:
         raise RuntimeError("Trying to use an uninitialized device handle.")
@@ -68,20 +68,20 @@ class DTensorSpec:
             self._hash = self._hash_impl()
         return self._hash

-    def __eq__(self, __o: object) -> bool:
+    def __eq__(self, other: object, /) -> bool:
         if not (
-            isinstance(__o, DTensorSpec)
-            and self.mesh == __o.mesh
-            and self.placements == __o.placements
+            isinstance(other, DTensorSpec)
+            and self.mesh == other.mesh
+            and self.placements == other.placements
         ):
             return False
-        if self.tensor_meta is None or __o.tensor_meta is None:
-            return self.tensor_meta == __o.tensor_meta
+        if self.tensor_meta is None or other.tensor_meta is None:
+            return self.tensor_meta == other.tensor_meta

         return (
-            self.tensor_meta.shape == __o.tensor_meta.shape  # type: ignore[union-attr]
-            and self.tensor_meta.stride == __o.tensor_meta.stride  # type: ignore[union-attr]
-            and self.tensor_meta.dtype == __o.tensor_meta.dtype  # type: ignore[union-attr]
+            self.tensor_meta.shape == other.tensor_meta.shape  # type: ignore[union-attr]
+            and self.tensor_meta.stride == other.tensor_meta.stride  # type: ignore[union-attr]
+            and self.tensor_meta.dtype == other.tensor_meta.dtype  # type: ignore[union-attr]
         )

     def __str__(self) -> str:
@@ -194,7 +194,7 @@ class ProcessContext:
             except ValueError:
                 name = f"<Unknown signal {-exitcode}>"
             raise ProcessExitedException(
-                "process %d terminated with signal %s" % (error_index, name),
+                f"process {error_index:d} terminated with signal {name}",
                 error_index=error_index,
                 error_pid=failed_process.pid,
                 exit_code=exitcode,

@@ -202,7 +202,7 @@ class ProcessContext:
             )
         else:
             raise ProcessExitedException(
-                "process %d terminated with exit code %d" % (error_index, exitcode),
+                f"process {error_index:d} terminated with exit code {exitcode:d}",
                 error_index=error_index,
                 error_pid=failed_process.pid,
                 exit_code=exitcode,

@@ -210,7 +210,7 @@ class ProcessContext:

         with open(self.error_files[error_index], "rb") as fh:
             original_trace = pickle.load(fh)
-        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
+        msg = f"\n\n-- Process {error_index:d} terminated with the following error:\n"
         msg += original_trace
         raise ProcessRaisedException(msg, error_index, failed_process.pid)
@@ -107,12 +107,12 @@ class _ModuleMeta:
         """
         return self._raw_meta

-    def __eq__(self, __value: object) -> bool:
-        if not isinstance(__value, _ModuleMeta):
+    def __eq__(self, other: object, /) -> bool:
+        if not isinstance(other, _ModuleMeta):
             return False
         return (
-            self._module_name == __value._module_name
-            and self._module_class == __value._module_class
+            self._module_name == other._module_name
+            and self._module_class == other._module_class
         )

     def __hash__(self) -> int:

@@ -286,10 +286,10 @@ class _ModuleStackMeta:
         """Pushes a module meta to the stack."""
         self._module_stack.append(module_meta)

-    def __eq__(self, __value: object) -> bool:
-        if not isinstance(__value, _ModuleStackMeta):
+    def __eq__(self, other: object, /) -> bool:
+        if not isinstance(other, _ModuleStackMeta):
             return False
-        return self._module_stack == __value._module_stack
+        return self._module_stack == other._module_stack

     @property
     def raw_meta(self) -> dict[str, tuple[str, type]] | None:
@@ -154,15 +154,15 @@ class ElementwiseTypePromotionRule(TypePromotionRule):
             f"{self.promote_args_positions}, {self.promote_kwargs_names}, {self.promotion_kind})"
         )

-    def __eq__(self, __value: object) -> bool:
-        if not isinstance(__value, ElementwiseTypePromotionRule):
+    def __eq__(self, other: object, /) -> bool:
+        if not isinstance(other, ElementwiseTypePromotionRule):
             return False
         return (
-            self.namespace == __value.namespace
-            and self.op_name == __value.op_name
-            and self.promote_args_positions == __value.promote_args_positions
-            and self.promote_kwargs_names == __value.promote_kwargs_names
-            and self.promotion_kind == __value.promotion_kind
+            self.namespace == other.namespace
+            and self.op_name == other.op_name
+            and self.promote_args_positions == other.promote_args_positions
+            and self.promote_kwargs_names == other.promote_kwargs_names
+            and self.promotion_kind == other.promotion_kind
         )

     def __hash__(self) -> int:

@@ -270,13 +270,13 @@ class ReductionTypePromotionRule(TypePromotionRule):
     def __repr__(self):
         return f"ReductionTypePromotionRule('{self.namespace}', '{self.op_name}', {self.promotion_kind})"

-    def __eq__(self, __value: object) -> bool:
-        if not isinstance(__value, ElementwiseTypePromotionRule):
+    def __eq__(self, other: object, /) -> bool:
+        if not isinstance(other, ElementwiseTypePromotionRule):
             return False
         return (
-            self.namespace == __value.namespace
-            and self.op_name == __value.op_name
-            and self.promotion_kind == __value.promotion_kind
+            self.namespace == other.namespace
+            and self.op_name == other.op_name
+            and self.promotion_kind == other.promotion_kind
         )

     def __hash__(self) -> int:
@@ -1532,8 +1532,8 @@ def _set_input_and_output_names(graph, input_names, output_names):
             return
         if len(name_list) > len(node_list):
             raise RuntimeError(
-                "number of %s names provided (%d) exceeded number of %ss (%d)"
-                % (descriptor, len(name_list), descriptor, len(node_list))
+                f"number of {descriptor} names provided ({len(name_list)}) "
+                f"exceeded number of {descriptor}s ({len(node_list)})"
             )

         # Mark if the output node DebugName is set before.
@@ -113,8 +113,8 @@ class PackagePickler(_PyTorchLegacyPickler):
                 )
             except UnicodeEncodeError as exc:
                 raise PicklingError(
-                    "can't pickle global identifier '%s.%s' using "
-                    "pickle protocol %i" % (module, name, self.proto)  # type: ignore[attr-defined]
+                    f"can't pickle global identifier '{module}.{name}' using "
+                    f"pickle protocol {self.proto:d}"  # type: ignore[attr-defined]
                 ) from exc

         self.memoize(obj)  # type: ignore[attr-defined]
@@ -4465,7 +4465,7 @@ def retry(ExceptionToCheck, tries=3, delay=3, skip_after_retries=False):
                 try:
                     return f(*args, **kwargs)
                 except ExceptionToCheck as e:
-                    msg = "%s, Retrying in %d seconds..." % (str(e), mdelay)
+                    msg = f"{e}, Retrying in {mdelay:d} seconds..."
                     print(msg)
                     time.sleep(mdelay)
                     mtries -= 1
@@ -71,7 +71,7 @@ def dist_init(
         rpc.constants.DEFAULT_SHUTDOWN_TIMEOUT = 60

         rpc.init_rpc(
-            name="worker%d" % self.rank,
+            name=f"worker{self.rank:d}",
             backend=self.rpc_backend,
             rank=self.rank,
             world_size=self.world_size,

@@ -44,7 +44,7 @@ class ShardedTensorTestBase(MultiProcessTestCase):
         )

         rpc.init_rpc(
-            name="worker%d" % self.rank,
+            name=f"worker{self.rank:d}",
             rank=self.rank,
             world_size=self.world_size,
             rpc_backend_options=rpc_backend_options,
@@ -454,7 +454,7 @@ def require_backend_is_available(backends):
 def require_world_size(world_size):
     if int(os.environ["WORLD_SIZE"]) < world_size:
         return skip_but_pass_in_sandcastle(
-            "Test requires world size of %d" % world_size
+            f"Test requires world size of {world_size:d}"
         )
     return lambda func: func

@@ -4061,7 +4061,7 @@ class DistributedTest:
                 self.assertGreaterAlmostEqual(
                     float(time.time()),
                     float(expected_time[0]),
-                    msg="destination rank: %d, my rank: %d" % (dest, rank)
+                    msg=f"destination rank: {dest:d}, my rank: {rank:d}"
                     + " (if you see this failure, please report in #14554)",
                 )

@@ -5181,7 +5181,7 @@ class DistributedTest:
             gradient_as_bucket_view=False,
         ):
             model = Net()
-            device = devices[0] if devices else torch.device("cuda:%d" % rank)
+            device = devices[0] if devices else torch.device(f"cuda:{rank:d}")
             ddp_model = DistributedDataParallel(
                 copy.deepcopy(model).to(device),
                 device_ids=device_ids,

@@ -5687,7 +5687,7 @@ class DistributedTest:
            )

            dist.barrier()
-           map_location = {"cuda:%d" % 0: "cuda:%d" % self.rank}
+           map_location = {"cuda:0": f"cuda:{self.rank:d}"}
            checkpoint = torch.load(chkpt_file, map_location=map_location)
            dummy_post_localSGD_opt.load_state_dict(checkpoint["optimizer_state_dict"])

@@ -10133,7 +10133,7 @@ class DistributedTest:
            )

            dist.barrier()
-           map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
+           map_location = {"cuda:0": f"cuda:{rank:d}"}
            with self.assertLogs("torch.distributed") as captured:
                checkpoint = torch.load(chkpt_file, map_location=map_location)
@@ -98,8 +98,8 @@ class DistOptimizerTest(RpcAgentTestFixture):
     @dist_init()
     def test_dist_optim_exception(self):
         # distributed version
-        owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
-        owner2 = "worker%d" % ((self.rank + 2) % self.world_size)
+        owner1 = f"worker{(self.rank + 1) % self.world_size:d}"
+        owner2 = f"worker{(self.rank + 2) % self.world_size:d}"

         remote_module1 = rpc.remote(owner1, MyModule)
         remote_module2 = rpc.remote(owner2, MyModule)

@@ -126,8 +126,8 @@ class DistOptimizerTest(RpcAgentTestFixture):
     @dist_init()
     def test_dist_optim_exception_on_constructor(self):
         # distributed version
-        owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
-        owner2 = "worker%d" % ((self.rank + 2) % self.world_size)
+        owner1 = f"worker{(self.rank + 1) % self.world_size:d}"
+        owner2 = f"worker{(self.rank + 2) % self.world_size:d}"

         remote_module1 = rpc.remote(owner1, MyModule)
         remote_module2 = rpc.remote(owner2, MyModule)

@@ -161,8 +161,8 @@ class DistOptimizerTest(RpcAgentTestFixture):
         local_optim.step()

         # distributed version
-        owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
-        owner2 = "worker%d" % ((self.rank + 2) % self.world_size)
+        owner1 = f"worker{(self.rank + 1) % self.world_size:d}"
+        owner2 = f"worker{(self.rank + 2) % self.world_size:d}"

         remote_module1 = rpc.remote(owner1, MyModule)
         remote_module2 = rpc.remote(owner2, MyModule)

@@ -232,8 +232,8 @@ class DistOptimizerTest(RpcAgentTestFixture):
         local_optim.step()

         # distributed version
-        owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
-        owner2 = "worker%d" % ((self.rank + 2) % self.world_size)
+        owner1 = f"worker{(self.rank + 1) % self.world_size:d}"
+        owner2 = f"worker{(self.rank + 2) % self.world_size:d}"

         remote_module1 = rpc.remote(owner1, MyModule)
         remote_module2 = rpc.remote(owner2, MyModule, args=(False,))
@@ -852,7 +852,7 @@ class RpcTestCommon:
     def _wait_all_workers(self, f, x):
         initialize_pg(self.file_init_method, self.rank, self.world_size)
         rpc.init_rpc(
-            name="worker%d" % self.rank,
+            name=f"worker{self.rank:d}",
             backend=self.rpc_backend,
             rank=self.rank,
             world_size=self.world_size,

@@ -874,7 +874,7 @@ class RpcTestCommon:
     def _wait_all_workers_twice(self, f, x):
         initialize_pg(self.file_init_method, self.rank, self.world_size)
         rpc.init_rpc(
-            name="worker%d" % self.rank,
+            name=f"worker{self.rank:d}",
             backend=self.rpc_backend,
             rank=self.rank,
             world_size=self.world_size,

@@ -1686,7 +1686,7 @@ class RpcTest(RpcAgentTestFixture, RpcTestCommon):
     def test_shutdown_followed_by_rpc(self):
         # Initialize RPC.
         rpc.init_rpc(
-            name="worker%d" % self.rank,
+            name=f"worker{self.rank:d}",
             backend=self.rpc_backend,
             rank=self.rank,
             world_size=self.world_size,

@@ -3279,7 +3279,7 @@ class RpcTest(RpcAgentTestFixture, RpcTestCommon):
         # test that we can start RPC and then immediately locally shutdown
         # without sending any messages.
         rpc.init_rpc(
-            name="worker%d" % self.rank,
+            name=f"worker{self.rank:d}",
             backend=self.rpc_backend,
             rank=self.rank,
             world_size=self.world_size,

@@ -3321,7 +3321,7 @@ class RpcTest(RpcAgentTestFixture, RpcTestCommon):
         # test that if a callee node has gone down, we raise an appropriate
         # exception instead of just crashing.
         rpc.init_rpc(
-            name="worker%d" % self.rank,
+            name=f"worker{self.rank:d}",
             backend=self.rpc_backend,
             rank=self.rank,
             world_size=self.world_size,

@@ -3368,7 +3368,7 @@ class RpcTest(RpcAgentTestFixture, RpcTestCommon):
     def test_local_shutdown_with_rpc(self):
         # test that we can start RPC, send RPCs, and then run local shutdown.
         rpc.init_rpc(
-            name="worker%d" % self.rank,
+            name=f"worker{self.rank:d}",
             backend=self.rpc_backend,
             rank=self.rank,
             world_size=self.world_size,

@@ -3708,7 +3708,7 @@ class RpcTest(RpcAgentTestFixture, RpcTestCommon):
     @dist_init(setup_rpc=False)
     def test_use_rref_after_shutdown(self):
         rpc.init_rpc(
-            name="worker%d" % self.rank,
+            name=f"worker{self.rank:d}",
             backend=self.rpc_backend,
             rank=self.rank,
             world_size=self.world_size,
@@ -567,37 +567,33 @@ MapOnlyFn = Callable[[T], Callable[[Any], Any]]
# These specializations help with type inference on the lambda passed to this
# function
@overload
-def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
+def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]:
    ...


@overload
-def map_only(
-    __type_or_types_or_pred: Type3[T, S, U],
-) -> MapOnlyFn[Fn3[T, S, U, Any]]:
+def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]:
    ...


@overload
-def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
+def map_only(type_or_types_or_pred: Type[T], /) -> MapOnlyFn[Fn[T, Any]]:
    ...


# This specialization is needed for the implementations below that call
@overload
-def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]:
+def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]:
    ...


@overload
-def map_only(
-    __type_or_types_or_pred: Callable[[Any], bool],
-) -> MapOnlyFn[FnAny[Any]]:
+def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]:
    ...


def map_only(
-    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
+    type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], /
) -> MapOnlyFn[FnAny[Any]]:
    """
    Suppose you are writing a tree_map over tensors, leaving everything

@@ -617,16 +613,16 @@ def map_only(

    You can also directly use 'tree_map_only'
    """
-    if isinstance(__type_or_types_or_pred, (type, tuple)) or (
+    if isinstance(type_or_types_or_pred, (type, tuple)) or (
        sys.version_info >= (3, 10)
-        and isinstance(__type_or_types_or_pred, types.UnionType)
+        and isinstance(type_or_types_or_pred, types.UnionType)
    ):

        def pred(x: Any) -> bool:
-            return isinstance(x, __type_or_types_or_pred)  # type: ignore[arg-type]
+            return isinstance(x, type_or_types_or_pred)  # type: ignore[arg-type]

-    elif callable(__type_or_types_or_pred):
-        pred = __type_or_types_or_pred  # type: ignore[assignment]
+    elif callable(type_or_types_or_pred):
+        pred = type_or_types_or_pred  # type: ignore[assignment]
    else:
        raise TypeError("Argument must be a type, a tuple of types, or a callable.")

@@ -644,7 +640,8 @@ def map_only(

@overload
def tree_map_only(
-    __type_or_types_or_pred: Type[T],
+    type_or_types_or_pred: Type[T],
+    /,
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -654,7 +651,8 @@ def tree_map_only(

@overload
def tree_map_only(
-    __type_or_types_or_pred: Type2[T, S],
+    type_or_types_or_pred: Type2[T, S],
+    /,
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -664,7 +662,8 @@ def tree_map_only(

@overload
def tree_map_only(
-    __type_or_types_or_pred: Type3[T, S, U],
+    type_or_types_or_pred: Type3[T, S, U],
+    /,
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -674,7 +673,8 @@ def tree_map_only(

@overload
def tree_map_only(
-    __type_or_types_or_pred: Callable[[Any], bool],
+    type_or_types_or_pred: Callable[[Any], bool],
+    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -683,17 +683,19 @@ def tree_map_only(

def tree_map_only(
-    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
+    type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
+    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
-    return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
+    return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)


@overload
def tree_map_only_(
-    __type_or_types_or_pred: Type[T],
+    type_or_types_or_pred: Type[T],
+    /,
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -703,7 +705,8 @@ def tree_map_only_(

@overload
def tree_map_only_(
-    __type_or_types_or_pred: Type2[T, S],
+    type_or_types_or_pred: Type2[T, S],
+    /,
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -713,7 +716,8 @@ def tree_map_only_(

@overload
def tree_map_only_(
-    __type_or_types_or_pred: Type3[T, S, U],
+    type_or_types_or_pred: Type3[T, S, U],
+    /,
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -723,7 +727,8 @@ def tree_map_only_(

@overload
def tree_map_only_(
-    __type_or_types_or_pred: Callable[[Any], bool],
+    type_or_types_or_pred: Callable[[Any], bool],
+    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -732,12 +737,13 @@ def tree_map_only_(

def tree_map_only_(
-    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
+    type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
+    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
-    return tree_map_(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
+    return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)


def tree_all(

@@ -760,7 +766,8 @@ def tree_any(

@overload
def tree_all_only(
-    __type_or_types: Type[T],
+    type_or_types: Type[T],
+    /,
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -770,7 +777,8 @@ def tree_all_only(

@overload
def tree_all_only(
-    __type_or_types: Type2[T, S],
+    type_or_types: Type2[T, S],
+    /,
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -780,7 +788,8 @@ def tree_all_only(

@overload
def tree_all_only(
-    __type_or_types: Type3[T, S, U],
+    type_or_types: Type3[T, S, U],
+    /,
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -789,18 +798,20 @@ def tree_all_only(

def tree_all_only(
-    __type_or_types: TypeAny,
+    type_or_types: TypeAny,
+    /,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    flat_args = tree_iter(tree, is_leaf=is_leaf)
-    return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
+    return all(pred(x) for x in flat_args if isinstance(x, type_or_types))


@overload
def tree_any_only(
-    __type_or_types: Type[T],
+    type_or_types: Type[T],
+    /,
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -810,7 +821,8 @@ def tree_any_only(

@overload
def tree_any_only(
-    __type_or_types: Type2[T, S],
+    type_or_types: Type2[T, S],
+    /,
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -820,7 +832,8 @@ def tree_any_only(

@overload
def tree_any_only(
-    __type_or_types: Type3[T, S, U],
+    type_or_types: Type3[T, S, U],
+    /,
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -829,13 +842,14 @@ def tree_any_only(

def tree_any_only(
-    __type_or_types: TypeAny,
+    type_or_types: TypeAny,
+    /,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    flat_args = tree_iter(tree, is_leaf=is_leaf)
-    return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))
+    return any(pred(x) for x in flat_args if isinstance(x, type_or_types))


def broadcast_prefix(
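# Note on the signature changes above (illustrative sketch, not part of the
# diff): a leading "__" on a parameter name is only a convention that type
# checkers treat as positional-only, while the "/" separator (PEP 570,
# Python 3.8+) enforces it at runtime. The name "fetch" below is hypothetical:
def fetch(key, /, default=None):
    return default  # body is immaterial for the sketch

fetch("a")             # OK: "key" passed positionally
fetch("a", default=1)  # OK: parameters after "/" may still be keywords
# fetch(key="a")       # TypeError: "key" is positional-only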
@@ -154,7 +154,7 @@ class Freezer:
        for i in range(0, len(m.bytecode), 16):
            outfp.write("\n\t")
            for c in bytes(m.bytecode[i : i + 16]):
-                outfp.write("%d," % c)
+                outfp.write(f"{c:d},")
        outfp.write("\n};\n")

    def compile_path(self, path: Path, top_package_path: Path):
@@ -1038,33 +1038,33 @@ MapOnlyFn = Callable[[T], Callable[[Any], Any]]
# These specializations help with type inference on the lambda passed to this
# function
@overload
-def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
+def map_only(type_or_types_or_pred: Type2[T, S], /) -> MapOnlyFn[Fn2[T, S, Any]]:
    ...


@overload
-def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
+def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U, Any]]:
    ...


@overload
-def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
+def map_only(type_or_types_or_pred: Type[T], /) -> MapOnlyFn[Fn[T, Any]]:
    ...


# This specialization is needed for the implementations below that call
@overload
-def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]:
+def map_only(type_or_types_or_pred: TypeAny, /) -> MapOnlyFn[FnAny[Any]]:
    ...


@overload
-def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]:
+def map_only(type_or_types_or_pred: Callable[[Any], bool], /) -> MapOnlyFn[FnAny[Any]]:
    ...


def map_only(
-    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]
+    type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], /
) -> MapOnlyFn[FnAny[Any]]:
    """
    Suppose you are writing a tree_map over tensors, leaving everything

@@ -1084,16 +1084,16 @@ def map_only(

    You can also directly use 'tree_map_only'
    """
-    if isinstance(__type_or_types_or_pred, (type, tuple)) or (
+    if isinstance(type_or_types_or_pred, (type, tuple)) or (
        sys.version_info >= (3, 10)
-        and isinstance(__type_or_types_or_pred, types.UnionType)
+        and isinstance(type_or_types_or_pred, types.UnionType)
    ):

        def pred(x: Any) -> bool:
-            return isinstance(x, __type_or_types_or_pred)  # type: ignore[arg-type]
+            return isinstance(x, type_or_types_or_pred)  # type: ignore[arg-type]

-    elif callable(__type_or_types_or_pred):
-        pred = __type_or_types_or_pred  # type: ignore[assignment]
+    elif callable(type_or_types_or_pred):
+        pred = type_or_types_or_pred  # type: ignore[assignment]
    else:
        raise TypeError("Argument must be a type, a tuple of types, or a callable.")

@@ -1111,7 +1111,8 @@ def map_only(

@overload
def tree_map_only(
-    __type_or_types_or_pred: Type[T],
+    type_or_types_or_pred: Type[T],
+    /,
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -1121,7 +1122,8 @@ def tree_map_only(

@overload
def tree_map_only(
-    __type_or_types_or_pred: Type2[T, S],
+    type_or_types_or_pred: Type2[T, S],
+    /,
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -1131,7 +1133,8 @@ def tree_map_only(

@overload
def tree_map_only(
-    __type_or_types_or_pred: Type3[T, S, U],
+    type_or_types_or_pred: Type3[T, S, U],
+    /,
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -1141,7 +1144,8 @@ def tree_map_only(

@overload
def tree_map_only(
-    __type_or_types_or_pred: Callable[[Any], bool],
+    type_or_types_or_pred: Callable[[Any], bool],
+    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -1150,17 +1154,19 @@ def tree_map_only(

def tree_map_only(
-    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
+    type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
+    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
-    return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
+    return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)


@overload
def tree_map_only_(
-    __type_or_types_or_pred: Type[T],
+    type_or_types_or_pred: Type[T],
+    /,
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -1170,7 +1176,8 @@ def tree_map_only_(

@overload
def tree_map_only_(
-    __type_or_types_or_pred: Type2[T, S],
+    type_or_types_or_pred: Type2[T, S],
+    /,
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -1180,7 +1187,8 @@ def tree_map_only_(

@overload
def tree_map_only_(
-    __type_or_types_or_pred: Type3[T, S, U],
+    type_or_types_or_pred: Type3[T, S, U],
+    /,
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -1190,7 +1198,8 @@ def tree_map_only_(

@overload
def tree_map_only_(
-    __type_or_types_or_pred: Callable[[Any], bool],
+    type_or_types_or_pred: Callable[[Any], bool],
+    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -1199,12 +1208,13 @@ def tree_map_only_(

def tree_map_only_(
-    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
+    type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
+    /,
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
-    return tree_map_(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
+    return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)


def tree_all(

@@ -1227,7 +1237,8 @@ def tree_any(

@overload
def tree_all_only(
-    __type_or_types: Type[T],
+    type_or_types: Type[T],
+    /,
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -1237,7 +1248,8 @@ def tree_all_only(

@overload
def tree_all_only(
-    __type_or_types: Type2[T, S],
+    type_or_types: Type2[T, S],
+    /,
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -1247,7 +1259,8 @@ def tree_all_only(

@overload
def tree_all_only(
-    __type_or_types: Type3[T, S, U],
+    type_or_types: Type3[T, S, U],
+    /,
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -1256,18 +1269,20 @@ def tree_all_only(

def tree_all_only(
-    __type_or_types: TypeAny,
+    type_or_types: TypeAny,
+    /,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    flat_args = tree_iter(tree, is_leaf=is_leaf)
-    return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
+    return all(pred(x) for x in flat_args if isinstance(x, type_or_types))


@overload
def tree_any_only(
-    __type_or_types: Type[T],
+    type_or_types: Type[T],
+    /,
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -1277,7 +1292,8 @@ def tree_any_only(

@overload
def tree_any_only(
-    __type_or_types: Type2[T, S],
+    type_or_types: Type2[T, S],
+    /,
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -1287,7 +1303,8 @@ def tree_any_only(

@overload
def tree_any_only(
-    __type_or_types: Type3[T, S, U],
+    type_or_types: Type3[T, S, U],
+    /,
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,

@@ -1296,13 +1313,14 @@ def tree_any_only(

def tree_any_only(
-    __type_or_types: TypeAny,
+    type_or_types: TypeAny,
+    /,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    flat_args = tree_iter(tree, is_leaf=is_leaf)
-    return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))
+    return any(pred(x) for x in flat_args if isinstance(x, type_or_types))


# Broadcasts a pytree to the provided TreeSpec and returns the flattened
@@ -224,8 +224,7 @@
   "outputs": [],
   "source": [
    "def print_helper(cls, obj):\n",
-    "    print(\"DataPipe[{}]\\nInstance type: {}\"\n",
-    "          .format(cls.type, obj.type))"
+    "    print(f\"DataPipe[{cls.type}]\\nInstance type: {obj.type}\")"
   ]
  },
  {
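# Note on the notebook cell above (illustrative sketch, not part of the
# diff): str.format and the f-string print the same text; with hypothetical
# attribute values standing in for cls.type and obj.type:
cls_type, obj_type = "int", "List[int]"
assert (
    "DataPipe[{}]\nInstance type: {}".format(cls_type, obj_type)
    == f"DataPipe[{cls_type}]\nInstance type: {obj_type}"
)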
@@ -208,23 +208,21 @@ def main() -> None:
    write_test_cpp(test_result, options.generated_ops_test_cpp_path)

    print(
-        "\ntotal grouped native ops: %d"
-        % len(gen.get_grouped_native_functions(native_functions))
+        f"\ntotal grouped native ops: {len(gen.get_grouped_native_functions(native_functions)):d}"
    )

-    print("grouped native ops with out variant: %d" % len(native_functions_groups))
+    print(f"grouped native ops with out variant: {len(native_functions_groups):d}")
    supported_functions_num = sum(len(groups) for groups in supported_functions_groups)
-    print("generated functions groups with out variant: %d" % supported_functions_num)
+    print(f"generated functions groups with out variant: {supported_functions_num:d}")

-    print("\nview grouped native ops: %d" % len(native_functions_view_groups))
+    print(f"\nview grouped native ops: {len(native_functions_view_groups):d}")
    supported_view_functions_num = sum(
        len(groups) for groups in supported_functions_view_groups
    )
-    print("generated functions view groups: %d" % supported_view_functions_num)
+    print(f"generated functions view groups: {supported_view_functions_num:d}")

    print(
-        "\noverall generated : %d"
-        % (supported_functions_num + supported_view_functions_num)
+        f"\noverall generated : {supported_functions_num + supported_view_functions_num:d}"
    )
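# Note on the prints above (illustrative sketch, not part of the diff):
# f-strings accept arbitrary expressions, so the "%" value tuple folds into
# the string itself. With hypothetical counts:
a, b = 10, 5
print(f"\noverall generated : {a + b:d}")  # prints "overall generated : 15" after a leading newline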