[BE] Fix all B022 useless-contextlib-suppress (#100335)
`contextlib.suppress()` with no arguments suppresses nothing, so these bare `suppress()` calls were only serving as placeholder "do nothing" context managers; `contextlib.nullcontext()` states that intent explicitly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100335
Approved by: https://github.com/Skylion007
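For context, here is a minimal standalone sketch of the conditional-context idiom used at all of these call sites and why `nullcontext()` is the clearer placeholder. The `timed`/`run_step` names are illustrative only, not from this PR:

```python
import contextlib
import time


@contextlib.contextmanager
def timed(label):
    # Stand-in for a real context manager such as model.no_sync() or a profiler.
    start = time.perf_counter()
    yield
    print(f"{label}: {time.perf_counter() - start:.6f}s")


def run_step(measure=False):
    # B022: contextlib.suppress() takes the exception types to swallow; with no
    # arguments it suppresses nothing, so using it as the "else" branch only
    # looks like error handling. contextlib.nullcontext() is the stdlib's
    # explicit no-op context manager for exactly this conditional pattern.
    ctx = timed("step") if measure else contextlib.nullcontext()
    with ctx:
        return sum(i * i for i in range(10_000))


run_step()              # runs under the no-op context
run_step(measure=True)  # runs under the timing context and prints the elapsed time
```

`nullcontext` also accepts an optional `enter_result`, which helps when the `with ... as x` form is needed.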
parent d7bdfd3454
commit 01abbfbaae

@@ -30,7 +30,6 @@ ignore = [
     "B007", "B008", "B017",
     "B018",  # Useless expression
     "B019", "B020",
-    "B022",  # Allow empty context manager
     "B023", "B024", "B026",
     "B028",  # No explicit `stacklevel` keyword argument found
     "B027", "B904", "B905",

@@ -250,7 +250,7 @@ class TestJoin(MultiProcessTestCase):
         with self.assertRaisesRegex(
             RuntimeError,
             expected_msg
-        ) if throw_on_early_termination else contextlib.suppress():
+        ) if throw_on_early_termination else contextlib.nullcontext():
             with Join(
                 allreducers,
                 enable=enable,

@@ -257,7 +257,7 @@ class TestFSDPCheckpoint(FSDPTest):
         offload_ctx = (
             get_patched_save_on_cpu()(pin_memory=True)
             if offload_activations
-            else contextlib.suppress()
+            else contextlib.nullcontext()
         )
         with offload_ctx:
             out = checkpoint(m, inp, use_reentrant=True)

@@ -1,7 +1,7 @@
 # Owner(s): ["oncall: distributed"]

 import sys
-from contextlib import suppress
+from contextlib import nullcontext
 from enum import auto, Enum
 from typing import Optional
 from unittest.mock import patch

@@ -76,7 +76,7 @@ class TestCommunication(FSDPTest):

     def _run_iter(self, fsdp_model, batch, use_no_sync: bool):
         """Runs an iteration inside or outside the ``no_sync()`` context."""
-        context = fsdp_model.no_sync() if use_no_sync else suppress()
+        context = fsdp_model.no_sync() if use_no_sync else nullcontext()
         with context:
             output = fsdp_model(*batch)
             loss = fsdp_model.module.get_loss(batch, output)

@@ -2,7 +2,7 @@

 import sys
 import warnings
-from contextlib import suppress
+from contextlib import nullcontext

 import torch
 from torch import distributed as dist

@@ -159,7 +159,7 @@ class TestFSDPExecOrder(FSDPTest):
                 expected_regex=regex,
             )
             if self.rank != 0
-            else suppress()
+            else nullcontext()
         )
         if self.rank != 0:
             fsdp_model.flip_path()

@@ -183,7 +183,7 @@ class TestGradAcc(FSDPTest):
         batch_idx = 0
         for config in configs:
             sync_context = (
-                fsdp_model.no_sync() if config.use_no_sync else contextlib.suppress()
+                fsdp_model.no_sync() if config.use_no_sync else contextlib.nullcontext()
             )
             with sync_context:
                 for _ in range(config.num_iters):

@@ -5,7 +5,7 @@ import os
 import sys
 import warnings
 from collections import namedtuple
-from contextlib import suppress
+from contextlib import nullcontext
 from copy import deepcopy
 from typing import Any, Tuple

@@ -371,7 +371,7 @@ class TestFSDPMisc(FSDPTest):
         context = (
             self.assertRaisesRegex(ValueError, f"cuda:{self.rank} vs cuda:0")
             if self.rank != 0
-            else suppress()
+            else nullcontext()
         )
         with context:
             NestedWrappedModule.init(

@@ -427,7 +427,7 @@ class TestFSDPMisc(FSDPTest):
                 )
             )
             if self.rank != 0
-            else suppress()
+            else nullcontext()
         )
         with context:
             module = FSDP(no_params, device_id=0)

@@ -3,7 +3,7 @@
 import io
 import itertools
 import sys
-from contextlib import suppress
+from contextlib import nullcontext
 from copy import deepcopy
 from functools import partial
 from typing import Any, Dict

@@ -899,7 +899,7 @@ class TestFSDPStateDict(FSDPTest):

         def _create_module(wrap_fsdp=True):
             LINEAR_SKIP = "linear_skip"
-            ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else suppress()
+            ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else nullcontext()
             with ctx:
                 module = SkipModel(double_nest=double_nest)
                 # Full name of linear_skip param tensors in SkipModel, as would be

@@ -522,7 +522,7 @@ class TestUnshardParams(TestUnshardParamsBase):

         def _get_error_context(is_supported: bool):
             return (
-                contextlib.suppress()
+                contextlib.nullcontext()
                 if is_supported
                 else self.assertRaises(NotImplementedError)
             )  # some configs are not implemented yet

@@ -9,7 +9,7 @@ import copy
 import os
 import sys
 import unittest
-from contextlib import suppress
+from contextlib import nullcontext
 from typing import Any, cast, List

 import numpy as np

@@ -301,7 +301,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
             (list(m.parameters()), None),  # `params` as a list
         ]
         for ctor_input, error in ctor_inputs:
-            context = self.assertRaises(error) if error else suppress()
+            context = self.assertRaises(error) if error else nullcontext()
             with context:
                 ZeroRedundancyOptimizer(
                     ctor_input,

@@ -371,7 +371,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
     @property
     def context(self):
         return (
-            suppress()
+            nullcontext()
             if not torch.cuda.is_available()
             else torch.cuda.device(self.rank)
         )

@@ -7,7 +7,7 @@ import sys
 import tempfile
 import threading
 import time
-from contextlib import suppress
+from contextlib import nullcontext
 from dataclasses import dataclass
 from datetime import timedelta
 from itertools import product

@@ -545,7 +545,7 @@ class CommonDistributedDataParallelTest:
         process_group = self._get_process_group()
         for use_bucket_view in (True, False):
             err_ctx = (
-                suppress() if not use_reentrant else
+                nullcontext() if not use_reentrant else
                 self.assertRaisesRegex(
                     RuntimeError,
                     "Expected to mark a variable ready only once."

@@ -577,7 +577,7 @@ class CommonDistributedDataParallelTest:
         process_group = self._get_process_group()
         for use_bucket_view in (True, False):
             err_ctx = (
-                suppress() if not use_reentrant else
+                nullcontext() if not use_reentrant else
                 self.assertRaisesRegex(
                     RuntimeError,
                     "Expected to mark a variable ready only once."

@@ -678,7 +678,7 @@ class CommonDistributedDataParallelTest:
                 self.assertRaisesRegex(
                     RuntimeError,
                     "Your training graph has changed in this iteration"
-                ) if static_graph and not use_reentrant else suppress()
+                ) if static_graph and not use_reentrant else nullcontext()
             )
             with err_ctx:
                 self._test_ddp_checkpointing(

@@ -5739,7 +5739,7 @@ for shape in [(1,), ()]:
                 "none of output has requires_grad=True"
             )
             if use_reentrant
-            else contextlib.suppress()
+            else contextlib.nullcontext()
         )

         a = torch.randn(2, 2, requires_grad=True)

@@ -899,7 +899,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
         context = (
             self._deregister_orig_params_ctx()
             if self._use_orig_params
-            else contextlib.suppress()
+            else contextlib.nullcontext()
         )
         with context:
             return super()._apply(*args, **kwargs)

@@ -916,7 +916,7 @@ def _get_should_profile():


 def _enable_rpc_profiler(should_profile, qualified_name, func, rpc_type, dst_worker_info):
-    ctx_manager = contextlib.suppress()
+    ctx_manager = contextlib.nullcontext()

     if should_profile:
         # Create appropriate string representation based on type of func

@@ -5,7 +5,7 @@ import os
 import re
 import sys
 from abc import ABC, abstractmethod
-from contextlib import suppress
+from contextlib import nullcontext
 from copy import deepcopy
 from enum import auto, Enum
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

@@ -117,7 +117,7 @@ def _zero_model(
     summon_full=True,
 ):
     """Zeros the parameters and optionally buffers of ``model`` in place."""
-    ctx = FSDP.summon_full_params(model) if summon_full else suppress()
+    ctx = FSDP.summon_full_params(model) if summon_full else nullcontext()
     with ctx:
         for param in model.parameters():
             with torch.no_grad():

@@ -1121,7 +1121,7 @@ class FSDPTest(MultiProcessTestCase):
                 "has parameters on cuda",
             )
             if expects_device_error
-            else suppress()
+            else nullcontext()
         )
         with context:
             fsdp_loss = self._train_for_several_steps(

@@ -251,7 +251,7 @@ class Trainer:
         else:
             input_batches = batches

-        with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.suppress():
+        with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext():
            for b in input_batches:
                with dist_autograd.context() as context_id:
                    output = self.hybrid_module.forward(b)

@@ -7,7 +7,7 @@ import sys
 import tempfile
 import time
 from collections import namedtuple, OrderedDict
-from contextlib import contextmanager, suppress
+from contextlib import contextmanager, nullcontext
 from datetime import timedelta
 from functools import reduce
 from typing import Union, NamedTuple, Callable, Any

@@ -1559,7 +1559,7 @@ class DistributedTest:
             torch.cuda.set_device(device_id)

             tensor = _build_tensor(rank + 1, device_id=device_id)
-            profiler_cls = profiler_ctx if profiler_ctx is not None else suppress()
+            profiler_cls = profiler_ctx if profiler_ctx is not None else nullcontext()
             with profiler_cls as prof:
                 for src in range(0, world_size):
                     if src == rank:

@@ -1629,7 +1629,7 @@ class DistributedTest:
             rank = dist.get_rank()
             send_size = rank + 1
             tensor = _build_tensor(send_size)
-            ctx = profiler_ctx if profiler_ctx is not None else suppress()
+            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
             with ctx as prof:
                 for src in range(0, dist.get_world_size()):
                     if src == rank:

@@ -1697,7 +1697,7 @@ class DistributedTest:
             recv_ranks = list()
             irecv_ranks = list()

-            ctx = profiler_ctx if profiler_ctx is not None else suppress()
+            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
             with ctx as prof:
                 for dst in range(0, dist.get_world_size()):
                     if dst == rank:

@@ -1801,7 +1801,7 @@ class DistributedTest:
             world_size = dist.get_world_size()
             send_recv_size = 10
             tensor = _build_tensor(send_recv_size, value=rank)
-            ctx = profiler_ctx if profiler_ctx is not None else suppress()
+            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
             with ctx as prof:
                 for dst in range(0, world_size):
                     if dst == rank:

@@ -1859,7 +1859,7 @@ class DistributedTest:
         def _test_isend(self, profiler_ctx):
             rank = dist.get_rank()
             world_size = dist.get_world_size()
-            ctx = profiler_ctx if profiler_ctx is not None else suppress()
+            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
             with ctx as prof:
                 if rank == 0:
                     requests = [

@@ -7262,7 +7262,7 @@ class DistributedTest:
                     "Detected at least one rank that exhausted inputs.",
                 )
             else:
-                exception_ctx = suppress()
+                exception_ctx = nullcontext()
             with exception_ctx:
                 with net.join(
                     throw_on_early_termination=test_case.throw_on_early_termination

@@ -7273,7 +7273,7 @@ class DistributedTest:
                     if i % sync_interval != 0:
                         context = net.no_sync()
                     else:
-                        context = suppress()
+                        context = nullcontext()
                     with context:
                         if isinstance(inp, tuple):
                             loss = net(*inp).sum()

@@ -59,11 +59,11 @@ def foo_add():
     return torch.add(torch.ones(1), torch.ones(1))

 def udf_with_torch_ops(device=-1, use_record_function=False):
-    device_ctx = contextlib.suppress() if device == -1 else torch.cuda.device(device)
+    device_ctx = contextlib.nullcontext() if device == -1 else torch.cuda.device(device)
     record_function_ctx = (
         torch.autograd.profiler.record_function("##forward##")
         if use_record_function
-        else contextlib.suppress()
+        else contextlib.nullcontext()
     )
     with device_ctx, record_function_ctx:
         t1, t2 = torch.ones(1), torch.ones(1)

@@ -2164,7 +2164,7 @@ class RpcTest(RpcAgentTestFixture, RpcTestCommon):
         if self.rank == 1:
             with p() as prof:
                 record_function_ctx_mgr = (
-                    contextlib.suppress()
+                    contextlib.nullcontext()
                     if not use_record_function
                     else torch.autograd.profiler.record_function(
                         "foo"