[BE]: Enable ruff rule TRY302 and apply fixes (#101874)
Removes useless try statements and unreachable code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101874
Approved by: https://github.com/malfet
This commit is contained in:
parent 1ac663d9f1
commit 3e2ea32dab
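
For context, ruff's TRY302 rule flags a `try` whose `except` handler immediately re-raises: the handler changes nothing about how the exception propagates, and any statement after the bare `raise` is unreachable. A minimal before/after sketch (hypothetical function name, not code from this commit):

# Before: the except block only re-raises, and the print is dead code.
def read_text(path):
    try:
        with open(path) as f:
            return f.read()
    except OSError:
        raise
        print("never reached")  # unreachable

# After: behavior is identical without the try/except.
def read_text(path):
    with open(path) as f:
        return f.read()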
@@ -1018,6 +1018,6 @@ init_command = [
     'python3',
     'tools/linter/adapters/pip_init.py',
     '--dry-run={{DRYRUN}}',
-    'ruff==0.0.265',
+    'ruff==0.0.269',
 ]
 is_formatter = true

@@ -66,6 +66,7 @@ select = [
     "W",
     # Not included in flake8
     "PLE",
+    "TRY302",
 ]
 
 [tool.ruff.per-file-ignores]

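The two configuration hunks above pin the linter to `ruff==0.0.269` and add `TRY302` to the enabled rule set. Assuming that pinned version's CLI, the autofixes in the hunks below can be reproduced with `ruff check --select TRY302 --fix .`.
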
@@ -150,8 +150,6 @@ class FakeDDP(nn.Module):
         DDP._active_ddp_module = self
         try:
             yield
-        except Exception:
-            raise
         finally:
             DDP._active_ddp_module = None
 
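Several hunks in this commit (here and in `_no_hook` and `DistributedDataParallel` below) keep their `finally` blocks: cleanup in `finally` already runs whether or not the body raises, so an intermediate `except Exception: raise` is a no-op. A minimal sketch with hypothetical names, not code from this commit:

from contextlib import contextmanager

_active_module = None

@contextmanager
def _track_active(module):
    global _active_module
    _active_module = module
    try:
        yield
    finally:  # runs on both clean exit and exception; no `except: raise` needed
        _active_module = None
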
@@ -66,13 +66,9 @@ class Errors:
         At the moment, only tests on "numpy.ndarray" are supported.
         """
         if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
-            try:
-                np.testing.assert_allclose(
-                    x, y, rtol=self.rtol, atol=self.atol, equal_nan=True, verbose=True
-                )
-            except AssertionError as e:
-                raise
-                k(f"{colonize(msg)}{str(e).lstrip()}")
+            np.testing.assert_allclose(
+                x, y, rtol=self.rtol, atol=self.atol, equal_nan=True, verbose=True
+            )
         else:
             raise RuntimeError("Unsupported almost equal test")
 
@@ -105,11 +101,7 @@ class Errors:
             new_msg = f"{colonize(msg)}In embedded parameter '{x.name}'"
             self.equalAndThen(t1, t2, new_msg, k)
         elif isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
-            try:
-                np.testing.assert_equal(x, y)
-            except AssertionError as e:
-                raise
-                k("{}{}".format(colonize(msg, ": "), str(e).lstrip()))
+            np.testing.assert_equal(x, y)
         else:
             if x != y:
                 # TODO: Better algorithm for lists

@@ -559,15 +559,12 @@ class TestSymNumberMagicMethods(TestCase):
         lambda_apply = getattr(operator, fn)
 
         def guard_fn(v):
-            try:
-                if type(v) in (SymBool, bool):
-                    return guard_bool(v)
-                elif type(v) in (SymFloat, float):
-                    return guard_float(v)
-                else:  # SymInt, int
-                    return guard_int(v)
-            except Exception as e:
-                raise e
+            if type(v) in (SymBool, bool):
+                return guard_bool(v)
+            elif type(v) in (SymFloat, float):
+                return guard_float(v)
+            else:  # SymInt, int
+                return guard_int(v)
 
         # Get reference result
         with maybe_xfail(inp1, inp2):

@@ -10487,52 +10487,47 @@ class TestConsistency(TestCaseMPS):
             # Forward check
             #
             forward_failed = False
-            try:
-                mps_sample = cpu_sample.transform(
-                    lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)
+            mps_sample = cpu_sample.transform(
+                lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)
 
-                cpu_args = [cpu_sample.input] + list(cpu_sample.args)
-                cpu_kwargs = cpu_sample.kwargs
-                mps_args = [mps_sample.input] + list(mps_sample.args)
-                mps_kwargs = mps_sample.kwargs
+            cpu_args = [cpu_sample.input] + list(cpu_sample.args)
+            cpu_kwargs = cpu_sample.kwargs
+            mps_args = [mps_sample.input] + list(mps_sample.args)
+            mps_kwargs = mps_sample.kwargs
 
-                # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
-                if (op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor)):
-                    mps_args[1] = cpu_args[1]
+            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
+            if (op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor)):
+                mps_args[1] = cpu_args[1]
 
-                cpu_out = op(*cpu_args, **cpu_kwargs)
-                mps_out = op(*mps_args, **mps_kwargs)
+            cpu_out = op(*cpu_args, **cpu_kwargs)
+            mps_out = op(*mps_args, **mps_kwargs)
 
-                if (op.name in self.FP32_LOW_PRECISION_LIST) and dtype == torch.float32:
-                    atol = 1e-4
-                    rtol = 3e-5
-                elif op.name == "nn.functional.conv2d" or op.name == "linalg.multi_dot" and dtype == torch.float32:
-                    atol = 1e-4
-                    rtol = 3e-5
-                elif (op.name in self.FP16_LOW_PRECISION_LIST) and dtype == torch.float16:
-                    atol = 1e-2
-                    rtol = 1e-2
-                elif (op.name == "masked.mean"):
-                    atol = 7e-4
-                    rtol = 2e-3
-                elif (op.name == "native_layer_norm"):
-                    atol = 1e-4
-                    rtol = 1.3e-5
-                elif (op.name == "norm" or op.name == "linalg.norm") and dtype == torch.float16:
-                    atol = 7e-4
-                    rtol = 1.5e-3
-                elif op.name == "unique" and cpu_kwargs["sorted"] is False:
-                    continue
-                else:
-                    atol = None
-                    rtol = None
+            if (op.name in self.FP32_LOW_PRECISION_LIST) and dtype == torch.float32:
+                atol = 1e-4
+                rtol = 3e-5
+            elif op.name == "nn.functional.conv2d" or op.name == "linalg.multi_dot" and dtype == torch.float32:
+                atol = 1e-4
+                rtol = 3e-5
+            elif (op.name in self.FP16_LOW_PRECISION_LIST) and dtype == torch.float16:
+                atol = 1e-2
+                rtol = 1e-2
+            elif (op.name == "masked.mean"):
+                atol = 7e-4
+                rtol = 2e-3
+            elif (op.name == "native_layer_norm"):
+                atol = 1e-4
+                rtol = 1.3e-5
+            elif (op.name == "norm" or op.name == "linalg.norm") and dtype == torch.float16:
+                atol = 7e-4
+                rtol = 1.5e-3
+            elif op.name == "unique" and cpu_kwargs["sorted"] is False:
+                continue
+            else:
+                atol = None
+                rtol = None
 
-                self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
+            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
 
-            except Exception as e:
-                raise e
-                forward_failed = True
-                all_forward_pass = False
 
             #
             # Backward check

@@ -46,8 +46,6 @@ def use_deterministic_algorithims(mode: bool, warn_only: bool):
     try:
         torch.use_deterministic_algorithms(mode, warn_only=warn_only)
         yield {}
-    except RuntimeError as err:
-        raise err
     finally:
         torch.use_deterministic_algorithms(previous_mode, warn_only=previous_warn_only)
 
@@ -60,14 +60,11 @@ def get_disallowed_checksums(
     """
     Return the set of disallowed checksums from all http_archive rules
     """
-    try:
-        # Use bazel to get the list of external dependencies in XML format
-        proc = subprocess.run(
-            [binary, "query", "kind(http_archive, //external:*)", "--output=xml"],
-            capture_output=True,
-        )
-    except OSError:
-        raise
+    # Use bazel to get the list of external dependencies in XML format
+    proc = subprocess.run(
+        [binary, "query", "kind(http_archive, //external:*)", "--output=xml"],
+        capture_output=True,
+    )
 
     stdout = str(proc.stdout, "utf-8").strip()
     root = ET.fromstring(stdout)

@@ -251,8 +251,6 @@ def sdp_kernel(enable_flash: bool = True, enable_math: bool = True, enable_mem_e
         enable_mem_efficient_sdp(enable_mem_efficient)
         enable_math_sdp(enable_math)
         yield{}
-    except RuntimeError as err:
-        raise err
     finally:
         enable_flash_sdp(previous_flash)
         enable_mem_efficient_sdp(previous_mem_efficient)

@@ -20,8 +20,6 @@ def _no_hook(module: nn.Module):
     checkpoint.state(module).enable_hook = False
     try:
         yield
-    except Exception:
-        raise
     finally:
         checkpoint.state(module).enable_hook = orig_enable_hook
 
@@ -1674,48 +1674,43 @@ def _coalescing_manager(
     if device:
         group._start_coalescing(device)
     cm = _CoalescingManager()
-    try:
-        yield cm
-    except Exception:
-        # Re-throw exception caught by code inside the context manager
-        raise
-    else:
-        op_list = _world.pg_coalesce_state.pop(group)
-        if op_list:
-            # Collectives supporting "Fast Path" coalescing are captured.
-            # See implementation in corresponding collective APIs.
-            # Currently supported:
-            # - coalesced `all_reduce`
-            # - coalesced `all_gather_into_tensor`
-            op0 = op_list[0].op
-            if op0 == all_reduce:
-                tensors = []
-                for op in op_list:
-                    tensors.append(op.tensor)
-                opts = AllreduceCoalescedOptions()
-                opts.reduceOp = op_list[0].redop
-                work = group.allreduce_coalesced(tensors, opts)
-            elif op0 == all_gather_into_tensor:
-                inputs = []
-                outputs = []
-                for op in op_list:
-                    inputs.append(op.tensor)
-                    outputs.append(op.dst_tensor)
-                work = group.allgather_into_tensor_coalesced(outputs, inputs)
-            else:
-                raise AssertionError(
-                    f"Coalescing manager does not support fast-path coalescing of {op0}, "
-                    f"yet {op0} is still recorded in op list. This is an internal error of c10d."
-                )
-
-        if device:
-            # Old style of letting each coll inside the context manager to call into C++ counterpart via python binding
-            work = group._end_coalescing(device)
-
-        if async_ops:
-            cm.append(work)
-        else:
-            work.wait()
+    yield cm
+    op_list = _world.pg_coalesce_state.pop(group)
+    if op_list:
+        # Collectives supporting "Fast Path" coalescing are captured.
+        # See implementation in corresponding collective APIs.
+        # Currently supported:
+        # - coalesced `all_reduce`
+        # - coalesced `all_gather_into_tensor`
+        op0 = op_list[0].op
+        if op0 == all_reduce:
+            tensors = []
+            for op in op_list:
+                tensors.append(op.tensor)
+            opts = AllreduceCoalescedOptions()
+            opts.reduceOp = op_list[0].redop
+            work = group.allreduce_coalesced(tensors, opts)
+        elif op0 == all_gather_into_tensor:
+            inputs = []
+            outputs = []
+            for op in op_list:
+                inputs.append(op.tensor)
+                outputs.append(op.dst_tensor)
+            work = group.allgather_into_tensor_coalesced(outputs, inputs)
+        else:
+            raise AssertionError(
+                f"Coalescing manager does not support fast-path coalescing of {op0}, "
+                f"yet {op0} is still recorded in op list. This is an internal error of c10d."
+            )
+
+    if device:
+        # Old style of letting each coll inside the context manager to call into C++ counterpart via python binding
+        work = group._end_coalescing(device)
+
+    if async_ops:
+        cm.append(work)
+    else:
+        work.wait()
 
 
 def batch_isend_irecv(p2p_op_list):

@@ -635,12 +635,9 @@ def _sharded_pre_load_state_dict_hook(
 def _replace_with_full_state_dict_type(fsdp_state: _FSDPState) -> Generator:
     old_state_dict_config = fsdp_state._state_dict_config
     old_state_dict_type = fsdp_state._state_dict_type
-    try:
-        fsdp_state._state_dict_config = FullStateDictConfig()
-        fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
-        yield
-    except Exception as e:
-        raise e
+    fsdp_state._state_dict_config = FullStateDictConfig()
+    fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
+    yield
     fsdp_state._state_dict_config = old_state_dict_config
     fsdp_state._state_dict_type = old_state_dict_type
 
@@ -750,16 +750,13 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
             state_dict_config (Optional[StateDictConfig]): the configuration for the
                 target ``state_dict_type``.
         """
-        try:
-            prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type(
-                module,
-                state_dict_type,
-                state_dict_config,
-                optim_state_dict_config,
-            )
-            yield
-        except Exception as e:
-            raise e
+        prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type(
+            module,
+            state_dict_type,
+            state_dict_config,
+            optim_state_dict_config,
+        )
+        yield
         FullyShardedDataParallel.set_state_dict_type(
             module,
             prev_state_dict_settings.state_dict_type,

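A note on the two FSDP hunks above: with `@contextmanager`, an exception raised in the managed block surfaces at the `yield`, so the old wrappers never reached their restore lines on error either (the `except` re-raised first). Dropping the `try` keeps that behavior exactly. A minimal sketch with hypothetical names, not code from this commit:

from contextlib import contextmanager

settings = {"mode": "default"}

@contextmanager
def _temporary_mode(mode):
    old_mode = settings["mode"]
    settings["mode"] = mode
    yield
    settings["mode"] = old_mode  # reached only on a clean exit, matching the old code
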
@@ -1364,8 +1364,6 @@ class DistributedDataParallel(Module, Joinable):
         DistributedDataParallel._active_ddp_module = self
         try:
             yield
-        except Exception:
-            raise
         finally:
             DistributedDataParallel._active_ddp_module = None
 
@@ -1587,10 +1587,7 @@ class RpcTest(RpcAgentTestFixture, RpcTestCommon):
         og_func = rpc.api._wait_all_workers
 
         def wait_all_workers_sleep(timeout):
-            try:
-                rpc.api._all_gather(SlowPickleClass(0.5), timeout=timeout)
-            except RuntimeError as ex:
-                raise ex
+            rpc.api._all_gather(SlowPickleClass(0.5), timeout=timeout)
 
         rpc.api._wait_all_workers = wait_all_workers_sleep
 