[BE] fix remaining flake8 v7 warnings (#159044)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159044
Approved by: https://github.com/Skylion007
ghstack dependencies: #159043
Xuehai Pan 2025-07-25 01:07:16 +08:00 committed by PyTorch MergeBot
parent f903bc475c
commit f5e2de928b
24 changed files with 90 additions and 91 deletions


@@ -12,7 +12,7 @@ ignore =
# to line this up with executable bit
EXE001,
# these ignores are from flake8-bugbear; please fix!
B007,B008,B017,B019,B023,B028,B903,B904,B905,B906,B907
B007,B008,B017,B019,B023,B028,B903,B904,B905,B906,B907,B908,B910
# these ignores are from flake8-comprehensions; please fix!
C407,
# these ignores are from flake8-logging-format; please fix!


@@ -134,7 +134,7 @@ class BenchmarkKernel:
print(
f"Failed to run {backend} backend on {self.name} kernel for {setting} due to {e}"
)
self.available_backends.remove(backend)
self.available_backends.remove(backend) # noqa: B909
continue
mem_bytes = self.get_memory_bytes(args_ref, kwargs_ref)
perf = Performance(setting, avg_time, mem_bytes)
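The `# noqa: B909` suppressions added throughout this commit mark loops that intentionally mutate the very container they iterate over, which flake8-bugbear's B909 check flags. A minimal sketch of the pattern and the usual fix (the backend names are made up for illustration):

backends = ["eager", "aot_eager", "inductor"]

# B909: removing from `backends` while iterating over it can skip elements,
# because the iterator's position shifts as the list shrinks.
for backend in backends:
    if backend == "aot_eager":
        backends.remove(backend)  # noqa: B909 -- intentional, as in this PR

# The conventional fix is to iterate over a snapshot instead:
for backend in list(backends):
    if backend == "inductor":
        backends.remove(backend)  # safe: the loop walks a copy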


@@ -181,7 +181,6 @@ html_theme_options = {
theme_variables = pytorch_sphinx_theme2.get_theme_variables()
html_context = {
"theme_variables": theme_variables,
"github_url": "https://github.com",
"github_user": "pytorch",
"github_repo": "pytorch",
@@ -189,7 +188,7 @@ html_context = {
"github_version": "main",
"pytorch_project": "docs",
"doc_path": "docs/source",
"theme_variables": theme_variables, # noqa: F601
"theme_variables": theme_variables,
# library links are defined in
# pytorch_sphinx_theme2/pytorch_sphinx_theme2/links.json
"library_links": theme_variables.get("library_links", []),


@@ -1165,27 +1165,23 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
# Increased tolerances are needed to pass when using TF32
# See: https://github.com/pytorch/pytorch/issues/67764
(
torch.testing.assert_close(
local_loss.cpu(),
ddp_loss.cpu(),
rtol=1e-03,
atol=1e-08,
),
"Losses differ between local optimizer and ZeRO",
torch.testing.assert_close(
local_loss.cpu(),
ddp_loss.cpu(),
rtol=1e-03,
atol=1e-08,
msg="Losses differ between local optimizer and ZeRO",
)
for local_p, ddp_p in zip(
local_model.parameters(), ddp_model.parameters()
):
(
torch.testing.assert_close(
local_p.cpu(),
ddp_p.cpu(),
rtol=1e-03,
atol=1e-04,
),
"Models differ after a step",
torch.testing.assert_close(
local_p.cpu(),
ddp_p.cpu(),
rtol=1e-03,
atol=1e-04,
msg="Models differ after a step",
)
@skipIfHpu
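The rewrite above fixes a silent no-op: parenthesizing an assertion call together with a message string builds a tuple, so the string never reaches the assertion (and on failure, the call raises before the tuple is even constructed). A minimal sketch of before and after:

import torch

a, b = torch.zeros(2), torch.zeros(2)

# Before: constructs and discards a 2-tuple; "tensors differ" is dead code.
(
    torch.testing.assert_close(a, b),
    "tensors differ",
)

# After: the message is actually attached to the assertion.
torch.testing.assert_close(a, b, msg="tensors differ")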


@@ -6369,7 +6369,7 @@ class TestLazyLogitsInitialization(DistributionsTestCase):
except NotImplementedError:
pass
self.assertNotIn("probs", dist.__dict__, msg=message)
dist.batch_shape, dist.event_shape
_ = (dist.batch_shape, dist.event_shape)
self.assertNotIn("probs", dist.__dict__, msg=message)
def test_lazy_probs_initialization(self):
@@ -6386,7 +6386,7 @@ class TestLazyLogitsInitialization(DistributionsTestCase):
except NotImplementedError:
pass
self.assertNotIn("logits", dist.__dict__, msg=message)
dist.batch_shape, dist.event_shape
_ = (dist.batch_shape, dist.event_shape)
self.assertNotIn("logits", dist.__dict__, msg=message)


@@ -991,7 +991,7 @@ class TestExceptionPropagation(TestCase):
s = OrderedSet([1, 2, 3])
try:
for i in s:
s.update([4])
s.update([4]) # noqa: B909
except RuntimeError:
pass
else:


@@ -8863,7 +8863,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
with torch.library._scoped_library("mylib", "FRAGMENT") as m:
def impl(a, b, c, d, e=2):
(a.add_(b[0] * c * e),)
a.add_(b[0] * c * e)
if d is not None:
d.add_(b[1])
@@ -8936,7 +8936,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
with torch.library._scoped_library("mylib", "FRAGMENT") as m:
def impl(a, b, c, d, e=2):
(a.add_(b[0] * c * e),)
a.add_(b[0] * c * e)
if d is not None:
d.add_(b[1])
return b[0] + b[1]


@@ -246,7 +246,7 @@ def duplicate_opinfo_for_prims(
new_opinfo = copy.deepcopy(opinfo)
new_opinfo.name = new_name
new_opinfo.op = getattr(torch.ops.prims, prims_name)
opinfos.append(new_opinfo)
opinfos.append(new_opinfo) # noqa: B909
return
raise RuntimeError(f"OpInfo '{name}' not found in the database.")


@@ -1,18 +1,5 @@
# Owner(s): ["oncall: profiler"]
# if tqdm is not shut down properly, it will leave the monitor thread alive.
# This causes an issue in the multithreading test because we check all events
# in that test with their tids. The events that correspond to these lingering
# threads all have TID of (uint64_t)(-1) which is invalid.
# The workaround is turning off the monitoring thread when tqdm is loaded.
# Since these are unit tests, it is safe to turn off the monitor thread.
try:
import tqdm
tqdm.tqdm.monitor_interval = 0
except ImportError:
pass
import json
import os
import tempfile
@@ -52,6 +39,19 @@ from torch.testing._internal.common_utils import (
from torch.utils._triton import has_triton
# if tqdm is not shut down properly, it will leave the monitor thread alive.
# This causes an issue in the multithreading test because we check all events
# in that test with their tids. The events that correspond to these lingering
# threads all have TID of (uint64_t)(-1) which is invalid.
# The workaround is turning off the monitoring thread when tqdm is loaded.
# Since these are unit tests, it is safe to turn off the monitor thread.
try:
import tqdm
tqdm.tqdm.monitor_interval = 0
except ImportError:
pass
Json = dict[str, Any]


@@ -1,19 +1,6 @@
# Owner(s): ["oncall: profiler"]
# ruff: noqa: F841
# if tqdm is not shut down properly, it will leave the monitor thread alive.
# This causes an issue in the multithreading test because we check all events
# in that test with their tids. The events that correspond to these lingering
# threads all have TID of (uint64_t)(-1) which is invalid.
# The workaround is turning off the monitoring thread when tqdm is loaded.
# Since these are unit tests, it is safe to turn off the monitor thread.
try:
import tqdm
tqdm.tqdm.monitor_interval = 0
except ImportError:
None
from typing import Any
import torch
@@ -29,6 +16,19 @@ from torch.profiler import kineto_available, record_function
from torch.testing._internal.common_utils import run_tests, TestCase
# if tqdm is not shut down properly, it will leave the monitor thread alive.
# This causes an issue in the multithreading test because we check all events
# in that test with their tids. The events that correspond to these lingering
# threads all have TID of (uint64_t)(-1) which is invalid.
# The workaround is turning off the monitoring thread when tqdm is loaded.
# Since these are unit tests, it is safe to turn off the monitor thread.
try:
import tqdm
tqdm.tqdm.monitor_interval = 0
except ImportError:
pass
Json = dict[str, Any]
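Two things change in each of these profiler test files: the tqdm workaround moves below the regular imports so that all module-level imports stay at the top of the file, and the bare `None` in the `except` body becomes `pass`, the idiomatic no-op statement. A sketch of the corrected shape (attributing this to particular flake8 codes is an assumption here):

import json  # regular imports come first, at the top of the module

try:
    import tqdm

    tqdm.tqdm.monitor_interval = 0  # stop tqdm's monitor thread in tests
except ImportError:
    pass  # `pass` is a statement; a bare `None` is a useless expression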


@@ -1,18 +1,5 @@
# Owner(s): ["oncall: profiler"]
# if tqdm is not shut down properly, it will leave the monitor thread alive.
# This causes an issue in the multithreading test because we check all events
# in that test with their tids. The events that correspond to these lingering
# threads all have TID of (uint64_t)(-1) which is invalid.
# The workaround is turning off the monitoring thread when tqdm is loaded.
# Since these are unit tests, it is safe to turn off the monitor thread.
try:
import tqdm
tqdm.tqdm.monitor_interval = 0
except ImportError:
None
import gc
import re
import textwrap
@@ -24,14 +11,25 @@ import torch
import torch.nn as nn
import torch.optim
import torch.utils.data
from torch._C._profiler import _TensorMetadata
from torch._C._profiler import _ExtraFields_PyCall, _TensorMetadata
from torch.profiler import _utils, profile
from torch.testing._internal.common_utils import run_tests, TestCase
Json = dict[str, Any]
# if tqdm is not shut down properly, it will leave the monitor thread alive.
# This causes an issue in the multithreading test because we check all events
# in that test with their tids. The events that correspond to these lingering
# threads all have TID of (uint64_t)(-1) which is invalid.
# The workaround is turning off the monitoring thread when tqdm is loaded.
# Since these are unit tests, it is safe to turn off the monitor thread.
try:
import tqdm
from torch._C._profiler import _ExtraFields_PyCall
tqdm.tqdm.monitor_interval = 0
except ImportError:
pass
Json = dict[str, Any]
def find_node_with_name(nodes, name):


@@ -827,7 +827,7 @@ class TestFuseFx(QuantizationTestCase):
named_modules = dict(m.named_modules())
for node in m.graph.nodes:
if node.op == "call_module" and type(named_modules[node.target]) == torch.nn.Conv2d:
self.assertTrue(len(node.args) == 2), "Expecting the fused op to have two arguments"
self.assertTrue(len(node.args) == 2, msg="Expecting the fused op to have two arguments")
def test_fusion_pattern_with_matchallnode(self):
"""This test tests that the node matched by MatchAllNode will be regared as an input


@@ -8097,7 +8097,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
with self.assertWarnsOnceRegex(
UserWarning, f"This overload of {func}_ is deprecated"):
getattr(out_tensor, func + "_")(1, b1, b2)
self.assertEqual(out_tensor, ref * 2),
self.assertEqual(out_tensor, ref * 2)
getattr(res3, func + "_")(b1, b2, beta=1)
self.assertEqual(out_tensor, res3)
@@ -8113,7 +8113,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2))
res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=.5)
self.assertEqual(res4, ref * 3),
self.assertEqual(res4, ref * 3)
nan = torch.full_like(out_tensor, math.nan)
res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1)


@@ -201,7 +201,7 @@ class _TestMultiProcessing:
try:
os.kill(pid, 0)
except ProcessLookupError:
pids.remove(pid)
pids.remove(pid) # noqa: B909
break
# This assert fails if any nested child process is still


@@ -714,16 +714,18 @@ class TestSparseSemiStructuredTraining(TestCase):
max_diff = (ref_gemm - pack_gemm).abs().argmax()
torch.testing.assert_close(
ref_gemm, pack_gemm,
**atol_rtol_kw[dtype]
), f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})"
**atol_rtol_kw[dtype],
msg=f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})",
)
# Test A.t@B
pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed_t, meta_t)
max_diff = (ref_gemm - pack_gemm).abs().argmax()
torch.testing.assert_close(
ref_gemm, pack_gemm,
**atol_rtol_kw[dtype]
), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})"
**atol_rtol_kw[dtype],
msg=f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})",
)
@training_dtypes
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")


@@ -807,10 +807,8 @@ class DeterministicAlgorithmsVariable(ContextWrappingVariable):
def _call_func(self, tx: "InstructionTranslator", values):
assert len(values) == 1
value = values[0]
(
tx.output.create_node(
"call_function", torch._C._set_deterministic_algorithms, (value,), {}
),
tx.output.create_node(
"call_function", torch._C._set_deterministic_algorithms, (value,), {}
)
torch._C._set_deterministic_algorithms(value)


@@ -545,7 +545,7 @@ class _TargetExpr(PatternExpr):
fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns)
for fn in fns:
if isinstance(fn, torch._ops.OpOverloadPacket):
fns.extend(getattr(fn, overload) for overload in fn.overloads())
fns.extend(getattr(fn, overload) for overload in fn.overloads()) # noqa: B909
self.fns = fns
self.fns_set = OrderedSet(fns)


@@ -574,7 +574,7 @@ class CachingAutotuner(KernelInterface):
assert hasattr(self, "_reload_kernel")
assert callable(self._reload_kernel)
self.fn = self._reload_kernel().fn
self.compile_results.append(self._precompile_config(new_config))
self.compile_results.append(self._precompile_config(new_config)) # noqa: B909
self._make_launchers()


@@ -3937,7 +3937,7 @@ class Scheduler:
if remaining:
for rd in remaining:
if self.fusable_read_and_write(rd, cd):
remaining.remove(rd)
remaining.remove(rd) # noqa: B909
remaining_deps = OrderedSet(
dep.name


@@ -408,7 +408,7 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
return True
if key in self.keys:
True
return True
unflattened_keys: list[str] = []
planner_data = metadata.planner_data.get(key)


@@ -64,7 +64,7 @@ def infer_symbol_values(
for right_var in right_vars:
if sp.sympify(right_var) == sp.sympify("s0"):
right_equation = sp.cancel(right_equation / right_var)
right_vars.remove(right_var)
right_vars.remove(right_var) # noqa: B909
var = right_vars[0]
idx = symbol_idx_dict[str(var)]


@@ -394,7 +394,7 @@ def split_module(
root_partition = root_partitions.pop()
sorted_partitions.append(root_partition)
for dependent in partitions[root_partition].dependents:
partitions[dependent].dependencies.pop(root_partition)
partitions[dependent].dependencies.pop(root_partition) # noqa: B909
if not partitions[dependent].dependencies:
root_partitions.append(dependent)
if len(sorted_partitions) != len(partitions):


@@ -11847,8 +11847,11 @@ op_db: list[OpInfo] = [
safe_val=2)),
BinaryUfuncInfo('add',
# NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate
ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \
else np.add(input, np.multiply(alpha, other)),
ref=lambda input, other, *, alpha=1: (
np.add(input, other)
if alpha == 1
else np.add(input, np.multiply(alpha, other))
),
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16,
torch.float16, torch.chalf),
dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
@@ -20498,8 +20501,11 @@ op_db: list[OpInfo] = [
'jiterator_binary',
op=torch.cuda.jiterator._create_jit_fn(
"template <typename T> T binary(T x, T y, T alpha) { return x + alpha * y; }", alpha=1),
ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \
else np.add(input, np.multiply(alpha, other)),
ref=lambda input, other, *, alpha=1: (
np.add(input, other)
if alpha == 1
else np.add(input, np.multiply(alpha, other))
),
dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-3.14),
supports_out=False,


@@ -336,7 +336,7 @@ def cuda_allocation_context():
def to_dot(nodes):
lines = ["digraph GraphName {", "node [shape=rect];", 'rankdir=LR;']
for i, n in enumerate(nodes):
lines.append(f'{i} [label={escape(n.label)}, color={ "red" if n.root else "black"}];')
lines.append(f'{i} [label={escape(n.label)}, color={"red" if n.root else "black"}];')
for i, f in enumerate(nodes):
for label, j in f.referrents: