[BE]: Enable ruff rule TRY302 and apply fixes (#101874)

Removes useless try statements and unreachable code.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101874
Approved by: https://github.com/malfet
This commit is contained in:
Aaron Gokaslan 2023-05-19 17:30:47 +00:00 committed by PyTorch MergeBot
parent 1ac663d9f1
commit 3e2ea32dab
15 changed files with 99 additions and 141 deletions

View File

@ -1018,6 +1018,6 @@ init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'ruff==0.0.265',
'ruff==0.0.269',
]
is_formatter = true

View File

@ -66,6 +66,7 @@ select = [
"W",
# Not included in flake8
"PLE",
"TRY302",
]
[tool.ruff.per-file-ignores]

View File

@ -150,8 +150,6 @@ class FakeDDP(nn.Module):
DDP._active_ddp_module = self
try:
yield
except Exception:
raise
finally:
DDP._active_ddp_module = None

View File

@ -66,13 +66,9 @@ class Errors:
At the moment, only tests on "numpy.ndarray" are supported.
"""
if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
try:
np.testing.assert_allclose(
x, y, rtol=self.rtol, atol=self.atol, equal_nan=True, verbose=True
)
except AssertionError as e:
raise
k(f"{colonize(msg)}{str(e).lstrip()}")
np.testing.assert_allclose(
x, y, rtol=self.rtol, atol=self.atol, equal_nan=True, verbose=True
)
else:
raise RuntimeError("Unsupported almost equal test")
@ -105,11 +101,7 @@ class Errors:
new_msg = f"{colonize(msg)}In embedded parameter '{x.name}'"
self.equalAndThen(t1, t2, new_msg, k)
elif isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
try:
np.testing.assert_equal(x, y)
except AssertionError as e:
raise
k("{}{}".format(colonize(msg, ": "), str(e).lstrip()))
np.testing.assert_equal(x, y)
else:
if x != y:
# TODO: Better algorithm for lists

View File

@ -559,15 +559,12 @@ class TestSymNumberMagicMethods(TestCase):
lambda_apply = getattr(operator, fn)
def guard_fn(v):
try:
if type(v) in (SymBool, bool):
return guard_bool(v)
elif type(v) in (SymFloat, float):
return guard_float(v)
else: # SymInt, int
return guard_int(v)
except Exception as e:
raise e
if type(v) in (SymBool, bool):
return guard_bool(v)
elif type(v) in (SymFloat, float):
return guard_float(v)
else: # SymInt, int
return guard_int(v)
# Get reference result
with maybe_xfail(inp1, inp2):

View File

@ -10487,52 +10487,47 @@ class TestConsistency(TestCaseMPS):
# Forward check
#
forward_failed = False
try:
mps_sample = cpu_sample.transform(
lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)
mps_sample = cpu_sample.transform(
lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)
cpu_args = [cpu_sample.input] + list(cpu_sample.args)
cpu_kwargs = cpu_sample.kwargs
mps_args = [mps_sample.input] + list(mps_sample.args)
mps_kwargs = mps_sample.kwargs
cpu_args = [cpu_sample.input] + list(cpu_sample.args)
cpu_kwargs = cpu_sample.kwargs
mps_args = [mps_sample.input] + list(mps_sample.args)
mps_kwargs = mps_sample.kwargs
# for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
if (op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor)):
mps_args[1] = cpu_args[1]
# for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
if (op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor)):
mps_args[1] = cpu_args[1]
cpu_out = op(*cpu_args, **cpu_kwargs)
mps_out = op(*mps_args, **mps_kwargs)
cpu_out = op(*cpu_args, **cpu_kwargs)
mps_out = op(*mps_args, **mps_kwargs)
if (op.name in self.FP32_LOW_PRECISION_LIST) and dtype == torch.float32:
atol = 1e-4
rtol = 3e-5
elif op.name == "nn.functional.conv2d" or op.name == "linalg.multi_dot" and dtype == torch.float32:
atol = 1e-4
rtol = 3e-5
elif (op.name in self.FP16_LOW_PRECISION_LIST) and dtype == torch.float16:
atol = 1e-2
rtol = 1e-2
elif (op.name == "masked.mean"):
atol = 7e-4
rtol = 2e-3
elif (op.name == "native_layer_norm"):
atol = 1e-4
rtol = 1.3e-5
elif (op.name == "norm" or op.name == "linalg.norm") and dtype == torch.float16:
atol = 7e-4
rtol = 1.5e-3
elif op.name == "unique" and cpu_kwargs["sorted"] is False:
continue
else:
atol = None
rtol = None
if (op.name in self.FP32_LOW_PRECISION_LIST) and dtype == torch.float32:
atol = 1e-4
rtol = 3e-5
elif op.name == "nn.functional.conv2d" or op.name == "linalg.multi_dot" and dtype == torch.float32:
atol = 1e-4
rtol = 3e-5
elif (op.name in self.FP16_LOW_PRECISION_LIST) and dtype == torch.float16:
atol = 1e-2
rtol = 1e-2
elif (op.name == "masked.mean"):
atol = 7e-4
rtol = 2e-3
elif (op.name == "native_layer_norm"):
atol = 1e-4
rtol = 1.3e-5
elif (op.name == "norm" or op.name == "linalg.norm") and dtype == torch.float16:
atol = 7e-4
rtol = 1.5e-3
elif op.name == "unique" and cpu_kwargs["sorted"] is False:
continue
else:
atol = None
rtol = None
self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
except Exception as e:
raise e
forward_failed = True
all_forward_pass = False
#
# Backward check

View File

@ -46,8 +46,6 @@ def use_deterministic_algorithims(mode: bool, warn_only: bool):
try:
torch.use_deterministic_algorithms(mode, warn_only=warn_only)
yield {}
except RuntimeError as err:
raise err
finally:
torch.use_deterministic_algorithms(previous_mode, warn_only=previous_warn_only)

View File

@ -60,14 +60,11 @@ def get_disallowed_checksums(
"""
Return the set of disallowed checksums from all http_archive rules
"""
try:
# Use bazel to get the list of external dependencies in XML format
proc = subprocess.run(
[binary, "query", "kind(http_archive, //external:*)", "--output=xml"],
capture_output=True,
)
except OSError:
raise
# Use bazel to get the list of external dependencies in XML format
proc = subprocess.run(
[binary, "query", "kind(http_archive, //external:*)", "--output=xml"],
capture_output=True,
)
stdout = str(proc.stdout, "utf-8").strip()
root = ET.fromstring(stdout)

View File

@ -251,8 +251,6 @@ def sdp_kernel(enable_flash: bool = True, enable_math: bool = True, enable_mem_e
enable_mem_efficient_sdp(enable_mem_efficient)
enable_math_sdp(enable_math)
yield{}
except RuntimeError as err:
raise err
finally:
enable_flash_sdp(previous_flash)
enable_mem_efficient_sdp(previous_mem_efficient)

View File

@ -20,8 +20,6 @@ def _no_hook(module: nn.Module):
checkpoint.state(module).enable_hook = False
try:
yield
except Exception:
raise
finally:
checkpoint.state(module).enable_hook = orig_enable_hook

View File

@ -1674,48 +1674,43 @@ def _coalescing_manager(
if device:
group._start_coalescing(device)
cm = _CoalescingManager()
try:
yield cm
except Exception:
# Re-throw exception caught by code inside the context manager
raise
else:
op_list = _world.pg_coalesce_state.pop(group)
if op_list:
# Collectives supporting "Fast Path" coalescing are captured.
# See implementation in corresponding collective APIs.
# Currently supported:
# - coalesced `all_reduce`
# - coalesced `all_gather_into_tensor`
op0 = op_list[0].op
if op0 == all_reduce:
tensors = []
for op in op_list:
tensors.append(op.tensor)
opts = AllreduceCoalescedOptions()
opts.reduceOp = op_list[0].redop
work = group.allreduce_coalesced(tensors, opts)
elif op0 == all_gather_into_tensor:
inputs = []
outputs = []
for op in op_list:
inputs.append(op.tensor)
outputs.append(op.dst_tensor)
work = group.allgather_into_tensor_coalesced(outputs, inputs)
else:
raise AssertionError(
f"Coalescing manager does not support fast-path coalescing of {op0}, "
f"yet {op0} is still recorded in op list. This is an internal error of c10d."
)
if device:
# Old style of letting each coll inside the context manager to call into C++ counterpart via python binding
work = group._end_coalescing(device)
if async_ops:
cm.append(work)
yield cm
op_list = _world.pg_coalesce_state.pop(group)
if op_list:
# Collectives supporting "Fast Path" coalescing are captured.
# See implementation in corresponding collective APIs.
# Currently supported:
# - coalesced `all_reduce`
# - coalesced `all_gather_into_tensor`
op0 = op_list[0].op
if op0 == all_reduce:
tensors = []
for op in op_list:
tensors.append(op.tensor)
opts = AllreduceCoalescedOptions()
opts.reduceOp = op_list[0].redop
work = group.allreduce_coalesced(tensors, opts)
elif op0 == all_gather_into_tensor:
inputs = []
outputs = []
for op in op_list:
inputs.append(op.tensor)
outputs.append(op.dst_tensor)
work = group.allgather_into_tensor_coalesced(outputs, inputs)
else:
work.wait()
raise AssertionError(
f"Coalescing manager does not support fast-path coalescing of {op0}, "
f"yet {op0} is still recorded in op list. This is an internal error of c10d."
)
if device:
# Old style of letting each coll inside the context manager to call into C++ counterpart via python binding
work = group._end_coalescing(device)
if async_ops:
cm.append(work)
else:
work.wait()
def batch_isend_irecv(p2p_op_list):

View File

@ -635,12 +635,9 @@ def _sharded_pre_load_state_dict_hook(
def _replace_with_full_state_dict_type(fsdp_state: _FSDPState) -> Generator:
old_state_dict_config = fsdp_state._state_dict_config
old_state_dict_type = fsdp_state._state_dict_type
try:
fsdp_state._state_dict_config = FullStateDictConfig()
fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
yield
except Exception as e:
raise e
fsdp_state._state_dict_config = FullStateDictConfig()
fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
yield
fsdp_state._state_dict_config = old_state_dict_config
fsdp_state._state_dict_type = old_state_dict_type

View File

@ -750,16 +750,13 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
state_dict_config (Optional[StateDictConfig]): the configuration for the
target ``state_dict_type``.
"""
try:
prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type(
module,
state_dict_type,
state_dict_config,
optim_state_dict_config,
)
yield
except Exception as e:
raise e
prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type(
module,
state_dict_type,
state_dict_config,
optim_state_dict_config,
)
yield
FullyShardedDataParallel.set_state_dict_type(
module,
prev_state_dict_settings.state_dict_type,

View File

@ -1364,8 +1364,6 @@ class DistributedDataParallel(Module, Joinable):
DistributedDataParallel._active_ddp_module = self
try:
yield
except Exception:
raise
finally:
DistributedDataParallel._active_ddp_module = None

View File

@ -1587,10 +1587,7 @@ class RpcTest(RpcAgentTestFixture, RpcTestCommon):
og_func = rpc.api._wait_all_workers
def wait_all_workers_sleep(timeout):
try:
rpc.api._all_gather(SlowPickleClass(0.5), timeout=timeout)
except RuntimeError as ex:
raise ex
rpc.api._all_gather(SlowPickleClass(0.5), timeout=timeout)
rpc.api._wait_all_workers = wait_all_workers_sleep