[BE] Fix all B022 useless-contextlib-suppress (#100335)

No arguments passed to contextlib.suppress. No exceptions will be suppressed and therefore this context manager is redundant

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100335
Approved by: https://github.com/Skylion007
This commit is contained in:
Justin Chu 2023-04-30 08:57:15 -07:00 committed by PyTorch MergeBot
parent d7bdfd3454
commit 01abbfbaae
18 changed files with 38 additions and 39 deletions

View File

@ -30,7 +30,6 @@ ignore = [
"B007", "B008", "B017",
"B018", # Useless expression
"B019", "B020",
"B022", # Allow empty context manager
"B023", "B024", "B026",
"B028", # No explicit `stacklevel` keyword argument found
"B027", "B904", "B905",

View File

@ -250,7 +250,7 @@ class TestJoin(MultiProcessTestCase):
with self.assertRaisesRegex(
RuntimeError,
expected_msg
) if throw_on_early_termination else contextlib.suppress():
) if throw_on_early_termination else contextlib.nullcontext():
with Join(
allreducers,
enable=enable,

View File

@ -257,7 +257,7 @@ class TestFSDPCheckpoint(FSDPTest):
offload_ctx = (
get_patched_save_on_cpu()(pin_memory=True)
if offload_activations
else contextlib.suppress()
else contextlib.nullcontext()
)
with offload_ctx:
out = checkpoint(m, inp, use_reentrant=True)

View File

@ -1,7 +1,7 @@
# Owner(s): ["oncall: distributed"]
import sys
from contextlib import suppress
from contextlib import nullcontext
from enum import auto, Enum
from typing import Optional
from unittest.mock import patch
@ -76,7 +76,7 @@ class TestCommunication(FSDPTest):
def _run_iter(self, fsdp_model, batch, use_no_sync: bool):
"""Runs an iteration inside or outside the ``no_sync()`` context."""
context = fsdp_model.no_sync() if use_no_sync else suppress()
context = fsdp_model.no_sync() if use_no_sync else nullcontext()
with context:
output = fsdp_model(*batch)
loss = fsdp_model.module.get_loss(batch, output)

View File

@ -2,7 +2,7 @@
import sys
import warnings
from contextlib import suppress
from contextlib import nullcontext
import torch
from torch import distributed as dist
@ -159,7 +159,7 @@ class TestFSDPExecOrder(FSDPTest):
expected_regex=regex,
)
if self.rank != 0
else suppress()
else nullcontext()
)
if self.rank != 0:
fsdp_model.flip_path()

View File

@ -183,7 +183,7 @@ class TestGradAcc(FSDPTest):
batch_idx = 0
for config in configs:
sync_context = (
fsdp_model.no_sync() if config.use_no_sync else contextlib.suppress()
fsdp_model.no_sync() if config.use_no_sync else contextlib.nullcontext()
)
with sync_context:
for _ in range(config.num_iters):

View File

@ -5,7 +5,7 @@ import os
import sys
import warnings
from collections import namedtuple
from contextlib import suppress
from contextlib import nullcontext
from copy import deepcopy
from typing import Any, Tuple
@ -371,7 +371,7 @@ class TestFSDPMisc(FSDPTest):
context = (
self.assertRaisesRegex(ValueError, f"cuda:{self.rank} vs cuda:0")
if self.rank != 0
else suppress()
else nullcontext()
)
with context:
NestedWrappedModule.init(
@ -427,7 +427,7 @@ class TestFSDPMisc(FSDPTest):
)
)
if self.rank != 0
else suppress()
else nullcontext()
)
with context:
module = FSDP(no_params, device_id=0)

View File

@ -3,7 +3,7 @@
import io
import itertools
import sys
from contextlib import suppress
from contextlib import nullcontext
from copy import deepcopy
from functools import partial
from typing import Any, Dict
@ -899,7 +899,7 @@ class TestFSDPStateDict(FSDPTest):
def _create_module(wrap_fsdp=True):
LINEAR_SKIP = "linear_skip"
ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else suppress()
ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else nullcontext()
with ctx:
module = SkipModel(double_nest=double_nest)
# Full name of linear_skip param tensors in SkipModel, as would be

View File

@ -522,7 +522,7 @@ class TestUnshardParams(TestUnshardParamsBase):
def _get_error_context(is_supported: bool):
return (
contextlib.suppress()
contextlib.nullcontext()
if is_supported
else self.assertRaises(NotImplementedError)
) # some configs are not implemented yet

View File

@ -9,7 +9,7 @@ import copy
import os
import sys
import unittest
from contextlib import suppress
from contextlib import nullcontext
from typing import Any, cast, List
import numpy as np
@ -301,7 +301,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
(list(m.parameters()), None), # `params` as a list
]
for ctor_input, error in ctor_inputs:
context = self.assertRaises(error) if error else suppress()
context = self.assertRaises(error) if error else nullcontext()
with context:
ZeroRedundancyOptimizer(
ctor_input,
@ -371,7 +371,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
@property
def context(self):
return (
suppress()
nullcontext()
if not torch.cuda.is_available()
else torch.cuda.device(self.rank)
)

View File

@ -7,7 +7,7 @@ import sys
import tempfile
import threading
import time
from contextlib import suppress
from contextlib import nullcontext
from dataclasses import dataclass
from datetime import timedelta
from itertools import product
@ -545,7 +545,7 @@ class CommonDistributedDataParallelTest:
process_group = self._get_process_group()
for use_bucket_view in (True, False):
err_ctx = (
suppress() if not use_reentrant else
nullcontext() if not use_reentrant else
self.assertRaisesRegex(
RuntimeError,
"Expected to mark a variable ready only once."
@ -577,7 +577,7 @@ class CommonDistributedDataParallelTest:
process_group = self._get_process_group()
for use_bucket_view in (True, False):
err_ctx = (
suppress() if not use_reentrant else
nullcontext() if not use_reentrant else
self.assertRaisesRegex(
RuntimeError,
"Expected to mark a variable ready only once."
@ -678,7 +678,7 @@ class CommonDistributedDataParallelTest:
self.assertRaisesRegex(
RuntimeError,
"Your training graph has changed in this iteration"
) if static_graph and not use_reentrant else suppress()
) if static_graph and not use_reentrant else nullcontext()
)
with err_ctx:
self._test_ddp_checkpointing(

View File

@ -5739,7 +5739,7 @@ for shape in [(1,), ()]:
"none of output has requires_grad=True"
)
if use_reentrant
else contextlib.suppress()
else contextlib.nullcontext()
)
a = torch.randn(2, 2, requires_grad=True)

View File

@ -899,7 +899,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
context = (
self._deregister_orig_params_ctx()
if self._use_orig_params
else contextlib.suppress()
else contextlib.nullcontext()
)
with context:
return super()._apply(*args, **kwargs)

View File

@ -916,7 +916,7 @@ def _get_should_profile():
def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info):
ctx_manager = contextlib.suppress()
ctx_manager = contextlib.nullcontext()
if should_profile:
# Create appropriate string representation based on type of func

View File

@ -5,7 +5,7 @@ import os
import re
import sys
from abc import ABC, abstractmethod
from contextlib import suppress
from contextlib import nullcontext
from copy import deepcopy
from enum import auto, Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
@ -117,7 +117,7 @@ def _zero_model(
summon_full=True,
):
"""Zeros the parameters and optionally buffers of ``model`` in place."""
ctx = FSDP.summon_full_params(model) if summon_full else suppress()
ctx = FSDP.summon_full_params(model) if summon_full else nullcontext()
with ctx:
for param in model.parameters():
with torch.no_grad():
@ -1121,7 +1121,7 @@ class FSDPTest(MultiProcessTestCase):
"has parameters on cuda",
)
if expects_device_error
else suppress()
else nullcontext()
)
with context:
fsdp_loss = self._train_for_several_steps(

View File

@ -251,7 +251,7 @@ class Trainer:
else:
input_batches = batches
with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.suppress():
with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext():
for b in input_batches:
with dist_autograd.context() as context_id:
output = self.hybrid_module.forward(b)

View File

@ -7,7 +7,7 @@ import sys
import tempfile
import time
from collections import namedtuple, OrderedDict
from contextlib import contextmanager, suppress
from contextlib import contextmanager, nullcontext
from datetime import timedelta
from functools import reduce
from typing import Union, NamedTuple, Callable, Any
@ -1559,7 +1559,7 @@ class DistributedTest:
torch.cuda.set_device(device_id)
tensor = _build_tensor(rank + 1, device_id=device_id)
profiler_cls = profiler_ctx if profiler_ctx is not None else suppress()
profiler_cls = profiler_ctx if profiler_ctx is not None else nullcontext()
with profiler_cls as prof:
for src in range(0, world_size):
if src == rank:
@ -1629,7 +1629,7 @@ class DistributedTest:
rank = dist.get_rank()
send_size = rank + 1
tensor = _build_tensor(send_size)
ctx = profiler_ctx if profiler_ctx is not None else suppress()
ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
with ctx as prof:
for src in range(0, dist.get_world_size()):
if src == rank:
@ -1697,7 +1697,7 @@ class DistributedTest:
recv_ranks = list()
irecv_ranks = list()
ctx = profiler_ctx if profiler_ctx is not None else suppress()
ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
with ctx as prof:
for dst in range(0, dist.get_world_size()):
if dst == rank:
@ -1801,7 +1801,7 @@ class DistributedTest:
world_size = dist.get_world_size()
send_recv_size = 10
tensor = _build_tensor(send_recv_size, value=rank)
ctx = profiler_ctx if profiler_ctx is not None else suppress()
ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
with ctx as prof:
for dst in range(0, world_size):
if dst == rank:
@ -1859,7 +1859,7 @@ class DistributedTest:
def _test_isend(self, profiler_ctx):
rank = dist.get_rank()
world_size = dist.get_world_size()
ctx = profiler_ctx if profiler_ctx is not None else suppress()
ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
with ctx as prof:
if rank == 0:
requests = [
@ -7262,7 +7262,7 @@ class DistributedTest:
"Detected at least one rank that exhausted inputs.",
)
else:
exception_ctx = suppress()
exception_ctx = nullcontext()
with exception_ctx:
with net.join(
throw_on_early_termination=test_case.throw_on_early_termination
@ -7273,7 +7273,7 @@ class DistributedTest:
if i % sync_interval != 0:
context = net.no_sync()
else:
context = suppress()
context = nullcontext()
with context:
if isinstance(inp, tuple):
loss = net(*inp).sum()

View File

@ -59,11 +59,11 @@ def foo_add():
return torch.add(torch.ones(1), torch.ones(1))
def udf_with_torch_ops(device=-1, use_record_function=False):
device_ctx = contextlib.suppress() if device == -1 else torch.cuda.device(device)
device_ctx = contextlib.nullcontext() if device == -1 else torch.cuda.device(device)
record_function_ctx = (
torch.autograd.profiler.record_function("##forward##")
if use_record_function
else contextlib.suppress()
else contextlib.nullcontext()
)
with device_ctx, record_function_ctx:
t1, t2 = torch.ones(1), torch.ones(1)
@ -2164,7 +2164,7 @@ class RpcTest(RpcAgentTestFixture, RpcTestCommon):
if self.rank == 1:
with p() as prof:
record_function_ctx_mgr = (
contextlib.suppress()
contextlib.nullcontext()
if not use_record_function
else torch.autograd.profiler.record_function(
"foo"