mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE] fix remaining flake8 v7 warnings (#159044)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159044 Approved by: https://github.com/Skylion007 ghstack dependencies: #159043
This commit is contained in:
parent
f903bc475c
commit
f5e2de928b
2
.flake8
2
.flake8
|
|
@ -12,7 +12,7 @@ ignore =
|
|||
# to line this up with executable bit
|
||||
EXE001,
|
||||
# these ignores are from flake8-bugbear; please fix!
|
||||
B007,B008,B017,B019,B023,B028,B903,B904,B905,B906,B907
|
||||
B007,B008,B017,B019,B023,B028,B903,B904,B905,B906,B907,B908,B910
|
||||
# these ignores are from flake8-comprehensions; please fix!
|
||||
C407,
|
||||
# these ignores are from flake8-logging-format; please fix!
|
||||
|
|
|
|||
|
|
@ -134,7 +134,7 @@ class BenchmarkKernel:
|
|||
print(
|
||||
f"Failed to run {backend} backend on {self.name} kernel for {setting} due to {e}"
|
||||
)
|
||||
self.available_backends.remove(backend)
|
||||
self.available_backends.remove(backend) # noqa: B909
|
||||
continue
|
||||
mem_bytes = self.get_memory_bytes(args_ref, kwargs_ref)
|
||||
perf = Performance(setting, avg_time, mem_bytes)
|
||||
|
|
|
|||
|
|
@ -181,7 +181,6 @@ html_theme_options = {
|
|||
|
||||
theme_variables = pytorch_sphinx_theme2.get_theme_variables()
|
||||
html_context = {
|
||||
"theme_variables": theme_variables,
|
||||
"github_url": "https://github.com",
|
||||
"github_user": "pytorch",
|
||||
"github_repo": "pytorch",
|
||||
|
|
@ -189,7 +188,7 @@ html_context = {
|
|||
"github_version": "main",
|
||||
"pytorch_project": "docs",
|
||||
"doc_path": "docs/source",
|
||||
"theme_variables": theme_variables, # noqa: F601
|
||||
"theme_variables": theme_variables,
|
||||
# library links are defined in
|
||||
# pytorch_sphinx_theme2/pytorch_sphinx_theme2/links.json
|
||||
"library_links": theme_variables.get("library_links", []),
|
||||
|
|
|
|||
|
|
@ -1165,27 +1165,23 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
|||
|
||||
# Increased tolerances are needed to pass when using TF32
|
||||
# See: https://github.com/pytorch/pytorch/issues/67764
|
||||
(
|
||||
torch.testing.assert_close(
|
||||
local_loss.cpu(),
|
||||
ddp_loss.cpu(),
|
||||
rtol=1e-03,
|
||||
atol=1e-08,
|
||||
),
|
||||
"Losses differ between local optimizer and ZeRO",
|
||||
torch.testing.assert_close(
|
||||
local_loss.cpu(),
|
||||
ddp_loss.cpu(),
|
||||
rtol=1e-03,
|
||||
atol=1e-08,
|
||||
msg="Losses differ between local optimizer and ZeRO",
|
||||
)
|
||||
|
||||
for local_p, ddp_p in zip(
|
||||
local_model.parameters(), ddp_model.parameters()
|
||||
):
|
||||
(
|
||||
torch.testing.assert_close(
|
||||
local_p.cpu(),
|
||||
ddp_p.cpu(),
|
||||
rtol=1e-03,
|
||||
atol=1e-04,
|
||||
),
|
||||
"Models differ after a step",
|
||||
torch.testing.assert_close(
|
||||
local_p.cpu(),
|
||||
ddp_p.cpu(),
|
||||
rtol=1e-03,
|
||||
atol=1e-04,
|
||||
msg="Models differ after a step",
|
||||
)
|
||||
|
||||
@skipIfHpu
|
||||
|
|
|
|||
|
|
@ -6369,7 +6369,7 @@ class TestLazyLogitsInitialization(DistributionsTestCase):
|
|||
except NotImplementedError:
|
||||
pass
|
||||
self.assertNotIn("probs", dist.__dict__, msg=message)
|
||||
dist.batch_shape, dist.event_shape
|
||||
_ = (dist.batch_shape, dist.event_shape)
|
||||
self.assertNotIn("probs", dist.__dict__, msg=message)
|
||||
|
||||
def test_lazy_probs_initialization(self):
|
||||
|
|
@ -6386,7 +6386,7 @@ class TestLazyLogitsInitialization(DistributionsTestCase):
|
|||
except NotImplementedError:
|
||||
pass
|
||||
self.assertNotIn("logits", dist.__dict__, msg=message)
|
||||
dist.batch_shape, dist.event_shape
|
||||
_ = (dist.batch_shape, dist.event_shape)
|
||||
self.assertNotIn("logits", dist.__dict__, msg=message)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -991,7 +991,7 @@ class TestExceptionPropagation(TestCase):
|
|||
s = OrderedSet([1, 2, 3])
|
||||
try:
|
||||
for i in s:
|
||||
s.update([4])
|
||||
s.update([4]) # noqa: B909
|
||||
except RuntimeError:
|
||||
pass
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -8863,7 +8863,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
|||
with torch.library._scoped_library("mylib", "FRAGMENT") as m:
|
||||
|
||||
def impl(a, b, c, d, e=2):
|
||||
(a.add_(b[0] * c * e),)
|
||||
a.add_(b[0] * c * e)
|
||||
if d is not None:
|
||||
d.add_(b[1])
|
||||
|
||||
|
|
@ -8936,7 +8936,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
|||
with torch.library._scoped_library("mylib", "FRAGMENT") as m:
|
||||
|
||||
def impl(a, b, c, d, e=2):
|
||||
(a.add_(b[0] * c * e),)
|
||||
a.add_(b[0] * c * e)
|
||||
if d is not None:
|
||||
d.add_(b[1])
|
||||
return b[0] + b[1]
|
||||
|
|
|
|||
|
|
@ -246,7 +246,7 @@ def duplicate_opinfo_for_prims(
|
|||
new_opinfo = copy.deepcopy(opinfo)
|
||||
new_opinfo.name = new_name
|
||||
new_opinfo.op = getattr(torch.ops.prims, prims_name)
|
||||
opinfos.append(new_opinfo)
|
||||
opinfos.append(new_opinfo) # noqa: B909
|
||||
return
|
||||
raise RuntimeError(f"OpInfo '{name}' not found in the database.")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,18 +1,5 @@
|
|||
# Owner(s): ["oncall: profiler"]
|
||||
|
||||
# if tqdm is not shutdown properly, it will leave the monitor thread alive.
|
||||
# This causes an issue in the multithreading test because we check all events
|
||||
# in that test with their tids. The events that correspond to these lingering
|
||||
# threads all have TID of (uint64_t)(-1) which is invalid.
|
||||
# The work around is turnning off monitoring thread when tqdm is loaded.
|
||||
# Since these are unit tests, it is safe to turn off monitor thread.
|
||||
try:
|
||||
import tqdm
|
||||
|
||||
tqdm.tqdm.monitor_interval = 0
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
|
@ -52,6 +39,19 @@ from torch.testing._internal.common_utils import (
|
|||
from torch.utils._triton import has_triton
|
||||
|
||||
|
||||
# if tqdm is not shutdown properly, it will leave the monitor thread alive.
|
||||
# This causes an issue in the multithreading test because we check all events
|
||||
# in that test with their tids. The events that correspond to these lingering
|
||||
# threads all have TID of (uint64_t)(-1) which is invalid.
|
||||
# The work around is turnning off monitoring thread when tqdm is loaded.
|
||||
# Since these are unit tests, it is safe to turn off monitor thread.
|
||||
try:
|
||||
import tqdm
|
||||
|
||||
tqdm.tqdm.monitor_interval = 0
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
Json = dict[str, Any]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +1,6 @@
|
|||
# Owner(s): ["oncall: profiler"]
|
||||
# ruff: noqa: F841
|
||||
|
||||
# if tqdm is not shutdown properly, it will leave the monitor thread alive.
|
||||
# This causes an issue in the multithreading test because we check all events
|
||||
# in that test with their tids. The events that correspond to these lingering
|
||||
# threads all have TID of (uint64_t)(-1) which is invalid.
|
||||
# The work around is turnning off monitoring thread when tqdm is loaded.
|
||||
# Since these are unit tests, it is safe to turn off monitor thread.
|
||||
try:
|
||||
import tqdm
|
||||
|
||||
tqdm.tqdm.monitor_interval = 0
|
||||
except ImportError:
|
||||
None
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
|
@ -29,6 +16,19 @@ from torch.profiler import kineto_available, record_function
|
|||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
# if tqdm is not shutdown properly, it will leave the monitor thread alive.
|
||||
# This causes an issue in the multithreading test because we check all events
|
||||
# in that test with their tids. The events that correspond to these lingering
|
||||
# threads all have TID of (uint64_t)(-1) which is invalid.
|
||||
# The work around is turnning off monitoring thread when tqdm is loaded.
|
||||
# Since these are unit tests, it is safe to turn off monitor thread.
|
||||
try:
|
||||
import tqdm
|
||||
|
||||
tqdm.tqdm.monitor_interval = 0
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
Json = dict[str, Any]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,18 +1,5 @@
|
|||
# Owner(s): ["oncall: profiler"]
|
||||
|
||||
# if tqdm is not shutdown properly, it will leave the monitor thread alive.
|
||||
# This causes an issue in the multithreading test because we check all events
|
||||
# in that test with their tids. The events that correspond to these lingering
|
||||
# threads all have TID of (uint64_t)(-1) which is invalid.
|
||||
# The work around is turnning off monitoring thread when tqdm is loaded.
|
||||
# Since these are unit tests, it is safe to turn off monitor thread.
|
||||
try:
|
||||
import tqdm
|
||||
|
||||
tqdm.tqdm.monitor_interval = 0
|
||||
except ImportError:
|
||||
None
|
||||
|
||||
import gc
|
||||
import re
|
||||
import textwrap
|
||||
|
|
@ -24,14 +11,25 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.optim
|
||||
import torch.utils.data
|
||||
from torch._C._profiler import _TensorMetadata
|
||||
from torch._C._profiler import _ExtraFields_PyCall, _TensorMetadata
|
||||
from torch.profiler import _utils, profile
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
Json = dict[str, Any]
|
||||
# if tqdm is not shutdown properly, it will leave the monitor thread alive.
|
||||
# This causes an issue in the multithreading test because we check all events
|
||||
# in that test with their tids. The events that correspond to these lingering
|
||||
# threads all have TID of (uint64_t)(-1) which is invalid.
|
||||
# The work around is turnning off monitoring thread when tqdm is loaded.
|
||||
# Since these are unit tests, it is safe to turn off monitor thread.
|
||||
try:
|
||||
import tqdm
|
||||
|
||||
from torch._C._profiler import _ExtraFields_PyCall
|
||||
tqdm.tqdm.monitor_interval = 0
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
Json = dict[str, Any]
|
||||
|
||||
|
||||
def find_node_with_name(nodes, name):
|
||||
|
|
|
|||
|
|
@ -827,7 +827,7 @@ class TestFuseFx(QuantizationTestCase):
|
|||
named_modules = dict(m.named_modules())
|
||||
for node in m.graph.nodes:
|
||||
if node.op == "call_module" and type(named_modules[node.target]) == torch.nn.Conv2d:
|
||||
self.assertTrue(len(node.args) == 2), "Expecting the fused op to have two arguments"
|
||||
self.assertTrue(len(node.args) == 2, msg="Expecting the fused op to have two arguments")
|
||||
|
||||
def test_fusion_pattern_with_matchallnode(self):
|
||||
"""This test tests that the node matched by MatchAllNode will be regared as an input
|
||||
|
|
|
|||
|
|
@ -8097,7 +8097,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
|
|||
with self.assertWarnsOnceRegex(
|
||||
UserWarning, f"This overload of {func}_ is deprecated"):
|
||||
getattr(out_tensor, func + "_")(1, b1, b2)
|
||||
self.assertEqual(out_tensor, ref * 2),
|
||||
self.assertEqual(out_tensor, ref * 2)
|
||||
getattr(res3, func + "_")(b1, b2, beta=1)
|
||||
self.assertEqual(out_tensor, res3)
|
||||
|
||||
|
|
@ -8113,7 +8113,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
|
|||
self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2))
|
||||
|
||||
res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=.5)
|
||||
self.assertEqual(res4, ref * 3),
|
||||
self.assertEqual(res4, ref * 3)
|
||||
|
||||
nan = torch.full_like(out_tensor, math.nan)
|
||||
res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1)
|
||||
|
|
|
|||
|
|
@ -201,7 +201,7 @@ class _TestMultiProcessing:
|
|||
try:
|
||||
os.kill(pid, 0)
|
||||
except ProcessLookupError:
|
||||
pids.remove(pid)
|
||||
pids.remove(pid) # noqa: B909
|
||||
break
|
||||
|
||||
# This assert fails if any nested child process is still
|
||||
|
|
|
|||
|
|
@ -714,16 +714,18 @@ class TestSparseSemiStructuredTraining(TestCase):
|
|||
max_diff = (ref_gemm - pack_gemm).abs().argmax()
|
||||
torch.testing.assert_close(
|
||||
ref_gemm, pack_gemm,
|
||||
**atol_rtol_kw[dtype]
|
||||
), f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})"
|
||||
**atol_rtol_kw[dtype],
|
||||
msg=f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})",
|
||||
)
|
||||
# Test A.t@B
|
||||
pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed_t, meta_t)
|
||||
max_diff = (ref_gemm - pack_gemm).abs().argmax()
|
||||
|
||||
torch.testing.assert_close(
|
||||
ref_gemm, pack_gemm,
|
||||
**atol_rtol_kw[dtype]
|
||||
), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})"
|
||||
**atol_rtol_kw[dtype],
|
||||
msg=f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})",
|
||||
)
|
||||
|
||||
@training_dtypes
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
|
||||
|
|
|
|||
|
|
@ -807,10 +807,8 @@ class DeterministicAlgorithmsVariable(ContextWrappingVariable):
|
|||
def _call_func(self, tx: "InstructionTranslator", values):
|
||||
assert len(values) == 1
|
||||
value = values[0]
|
||||
(
|
||||
tx.output.create_node(
|
||||
"call_function", torch._C._set_deterministic_algorithms, (value,), {}
|
||||
),
|
||||
tx.output.create_node(
|
||||
"call_function", torch._C._set_deterministic_algorithms, (value,), {}
|
||||
)
|
||||
torch._C._set_deterministic_algorithms(value)
|
||||
|
||||
|
|
|
|||
|
|
@ -545,7 +545,7 @@ class _TargetExpr(PatternExpr):
|
|||
fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns)
|
||||
for fn in fns:
|
||||
if isinstance(fn, torch._ops.OpOverloadPacket):
|
||||
fns.extend(getattr(fn, overload) for overload in fn.overloads())
|
||||
fns.extend(getattr(fn, overload) for overload in fn.overloads()) # noqa: B909
|
||||
|
||||
self.fns = fns
|
||||
self.fns_set = OrderedSet(fns)
|
||||
|
|
|
|||
|
|
@ -574,7 +574,7 @@ class CachingAutotuner(KernelInterface):
|
|||
assert hasattr(self, "_reload_kernel")
|
||||
assert callable(self._reload_kernel)
|
||||
self.fn = self._reload_kernel().fn
|
||||
self.compile_results.append(self._precompile_config(new_config))
|
||||
self.compile_results.append(self._precompile_config(new_config)) # noqa: B909
|
||||
|
||||
self._make_launchers()
|
||||
|
||||
|
|
|
|||
|
|
@ -3937,7 +3937,7 @@ class Scheduler:
|
|||
if remaining:
|
||||
for rd in remaining:
|
||||
if self.fusable_read_and_write(rd, cd):
|
||||
remaining.remove(rd)
|
||||
remaining.remove(rd) # noqa: B909
|
||||
|
||||
remaining_deps = OrderedSet(
|
||||
dep.name
|
||||
|
|
|
|||
|
|
@ -408,7 +408,7 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
|
|||
return True
|
||||
|
||||
if key in self.keys:
|
||||
True
|
||||
return True
|
||||
|
||||
unflattened_keys: list[str] = []
|
||||
planner_data = metadata.planner_data.get(key)
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ def infer_symbol_values(
|
|||
for right_var in right_vars:
|
||||
if sp.sympify(right_var) == sp.sympify("s0"):
|
||||
right_equation = sp.cancel(right_equation / right_var)
|
||||
right_vars.remove(right_var)
|
||||
right_vars.remove(right_var) # noqa: B909
|
||||
|
||||
var = right_vars[0]
|
||||
idx = symbol_idx_dict[str(var)]
|
||||
|
|
|
|||
|
|
@ -394,7 +394,7 @@ def split_module(
|
|||
root_partition = root_partitions.pop()
|
||||
sorted_partitions.append(root_partition)
|
||||
for dependent in partitions[root_partition].dependents:
|
||||
partitions[dependent].dependencies.pop(root_partition)
|
||||
partitions[dependent].dependencies.pop(root_partition) # noqa: B909
|
||||
if not partitions[dependent].dependencies:
|
||||
root_partitions.append(dependent)
|
||||
if len(sorted_partitions) != len(partitions):
|
||||
|
|
|
|||
|
|
@ -11847,8 +11847,11 @@ op_db: list[OpInfo] = [
|
|||
safe_val=2)),
|
||||
BinaryUfuncInfo('add',
|
||||
# NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate
|
||||
ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \
|
||||
else np.add(input, np.multiply(alpha, other)),
|
||||
ref=lambda input, other, *, alpha=1: (
|
||||
np.add(input, other)
|
||||
if alpha == 1
|
||||
else np.add(input, np.multiply(alpha, other))
|
||||
),
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16,
|
||||
torch.float16, torch.chalf),
|
||||
dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32),
|
||||
|
|
@ -20498,8 +20501,11 @@ op_db: list[OpInfo] = [
|
|||
'jiterator_binary',
|
||||
op=torch.cuda.jiterator._create_jit_fn(
|
||||
"template <typename T> T binary(T x, T y, T alpha) { return x + alpha * y; }", alpha=1),
|
||||
ref=lambda input, other, *, alpha=1: np.add(input, other) if alpha == 1 \
|
||||
else np.add(input, np.multiply(alpha, other)),
|
||||
ref=lambda input, other, *, alpha=1: (
|
||||
np.add(input, other)
|
||||
if alpha == 1
|
||||
else np.add(input, np.multiply(alpha, other))
|
||||
),
|
||||
dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
|
||||
sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-3.14),
|
||||
supports_out=False,
|
||||
|
|
|
|||
|
|
@ -336,7 +336,7 @@ def cuda_allocation_context():
|
|||
def to_dot(nodes):
|
||||
lines = ["digraph GraphName {", "node [shape=rect];", 'rankdir=LR;']
|
||||
for i, n in enumerate(nodes):
|
||||
lines.append(f'{i} [label={escape(n.label)}, color={ "red" if n.root else "black"}];')
|
||||
lines.append(f'{i} [label={escape(n.label)}, color={"red" if n.root else "black"}];')
|
||||
|
||||
for i, f in enumerate(nodes):
|
||||
for label, j in f.referrents:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user