[BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)

Remove useless parentheses in `raise` statements if the exception type is raised with no argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
This commit is contained in:
Xuehai Pan 2024-04-17 19:29:30 +00:00 committed by PyTorch MergeBot
parent b726a23d4e
commit 93e249969b
98 changed files with 297 additions and 296 deletions

View File

@ -1858,7 +1858,7 @@ class TimeOutException(Exception):
def alarm_handler(signum, frame):
raise TimeOutException()
raise TimeOutException
def exit_after(s):
@ -2136,7 +2136,7 @@ class BenchmarkRunner:
return set()
def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
raise NotImplementedError()
raise NotImplementedError
@property
def equal_nan(self):

View File

@ -80,7 +80,7 @@ def serialize_sparse_tensor(e):
def deserialize_sparse_tensor(size, dtype, layout, is_coalesced, nnz=None):
raise NotImplementedError()
raise NotImplementedError
def deserialize_tensor(size, dtype, stride=None):

View File

@ -126,6 +126,7 @@ select = [
"PT025",
"PT026",
"PYI",
"RSE",
"RUF008", # mutable dataclass default
"RUF015", # access first ele in constant time
"RUF016", # type error non-integer index

View File

@ -578,7 +578,7 @@ def main():
with open(filename, "w") as f:
f.writelines(lines)
return
raise AssertionError()
raise AssertionError
if __name__ == "__main__":

View File

@ -6,10 +6,10 @@ import torch
class Setup:
def setup(self):
raise NotImplementedError()
raise NotImplementedError
def shutdown(self):
raise NotImplementedError()
raise NotImplementedError
class FileSetup:

View File

@ -21,7 +21,7 @@ class Dummymodel(nn.Module):
super().__init__()
def forward(self, x):
raise NotImplementedError()
raise NotImplementedError
class EPModel(nn.Module):
@ -31,7 +31,7 @@ class EPModel(nn.Module):
self.net2 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
def forward(self, x):
raise NotImplementedError()
raise NotImplementedError
class SecondTier(nn.Module):
@ -43,7 +43,7 @@ class SecondTier(nn.Module):
self.net = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
def forward(self, x):
raise NotImplementedError()
raise NotImplementedError
class TopModel(nn.Module):
@ -55,7 +55,7 @@ class TopModel(nn.Module):
self.net = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
def forward(self, x):
raise NotImplementedError()
raise NotImplementedError
class TestFSDPWithEP(DTensorTestBase, VerifyStateDictMixin):

View File

@ -34,7 +34,7 @@ class TestMetricsHandler(MetricHandler):
class Parent(abc.ABC):
@abc.abstractmethod
def func(self):
raise NotImplementedError()
raise NotImplementedError
def base_func(self):
self.func()
@ -57,7 +57,7 @@ class MetricsApiTest(TestCase):
@prof
def throw(self):
raise RuntimeError()
raise RuntimeError
@prof(group="torchelastic")
def bar2(self):

View File

@ -197,7 +197,7 @@ class _DummyRendezvousHandler(RendezvousHandler):
return "dummy_backend"
def next_rendezvous(self) -> Tuple[Store, int, int]:
raise NotImplementedError()
raise NotImplementedError
def is_closed(self) -> bool:
return False

View File

@ -38,7 +38,7 @@ def _create_c10d_store_mp(is_server, server_addr, port, world_size, wait_for_wor
timeout=2,
)
if store is None:
raise AssertionError()
raise AssertionError
store.set(f"test_key/{os.getpid()}", b"test_value")

View File

@ -363,7 +363,7 @@ class TestFSDPOptimState(FSDPTest):
# these settings are not implemented since the transformer is
# wrapped with FSDP at the top-level, which means that there is
# only a single flat parameter, making these booleans vacuous
raise NotImplementedError()
raise NotImplementedError
if group is None:
group = dist.distributed_c10d._get_default_group()
model = TransformerWithSharedParams.init(

View File

@ -63,7 +63,7 @@ def test_exception_no_hang(setup_rpc):
class Raise(nn.Module):
def forward(self, x):
raise ExpectedException()
raise ExpectedException
model = nn.Sequential(Pass(), Pass(), Raise())
model = Pipe(model, chunks=3)

View File

@ -231,7 +231,7 @@ def test_exception(setup_rpc):
class Raise(nn.Module):
def forward(self, *_):
raise ExpectedException()
raise ExpectedException
model = nn.Sequential(Raise())
model = Pipe(model, chunks=1)
@ -265,7 +265,7 @@ def test_exception_early_stop_asap(setup_rpc):
class Raise(nn.Module):
def forward(self, x):
raise ExpectedException()
raise ExpectedException
model = nn.Sequential(Pass(), Pass(), Counter(), Raise())
model = Pipe(model, chunks=3)

View File

@ -752,7 +752,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
new_data = args[0]._data.view(*args[1:])
return FooTensor(new_data, args[0]._config, args[0]._scale)
raise NotImplementedError()
raise NotImplementedError
class foo_autograd_fn(torch.autograd.Function):
@staticmethod

View File

@ -47,7 +47,7 @@ from user code:
def test_internal_error_suppress_errors(self, records):
def fn001(x):
def f(ctx):
raise AssertionError()
raise AssertionError
comptime(f)
@ -62,7 +62,7 @@ WON'T CONVERT fn001 test_exc.py line N
========== TorchDynamo Stack Trace ==========
Traceback (most recent call last):
File "test_exc.py", line N, in f
raise AssertionError()
raise AssertionError
AssertionError:
from user code:
@ -84,7 +84,7 @@ from user code:
def test_not_implemented_error(self, records):
def fn001(x):
def f(ctx):
raise NotImplementedError()
raise NotImplementedError
# Ensure graph break is not possible
for i in range(3):
@ -101,7 +101,7 @@ WON'T CONVERT fn001 test_exc.py line N
due to:
Traceback (most recent call last):
File "test_exc.py", line N, in f
raise NotImplementedError()
raise NotImplementedError
torch._dynamo.exc.InternalTorchDynamoError:
from user code:
@ -128,7 +128,7 @@ from user code:
# NB: avoid decorator, as 3.11 changed the line number attributed
# in this situation
def f(ctx):
raise AssertionError()
raise AssertionError
comptime(f)

View File

@ -164,7 +164,7 @@ from user code:
import torch._inductor.lowering
def throw(x):
raise AssertionError()
raise AssertionError
# inject an error in the lowerings
dict_entries = {}
@ -189,7 +189,7 @@ WON'T CONVERT inductor_error_fn test_logging.py line N
due to:
Traceback (most recent call last):
File "test_logging.py", line N, in throw
raise AssertionError()
raise AssertionError
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AssertionError:
target: aten.round.default

View File

@ -396,12 +396,12 @@ class ListConfig:
if self.resolve:
x = x._dereference_node()
if x._is_missing():
raise AssertionError()
raise AssertionError
self.index = self.index + 1
if isinstance(x, ListConfig.ValueNode):
return x._value()
raise AssertionError()
raise AssertionError
def __iter__(self):
return self._iter_ex(True)
@ -410,7 +410,7 @@ class ListConfig:
try:
return ListConfig.ListIterator(self, resolve)
except Exception:
raise AssertionError()
raise AssertionError
def __init__(self):
self._content = [
@ -545,7 +545,7 @@ def apply_chunking_to_forward(forward_fn, *input_tensors):
assert all(input_tensor.shape[1] == tensor_shape for input_tensor in input_tensors)
num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
if num_args_in_forward_chunk_fn != len(input_tensors):
raise ValueError()
raise ValueError
return forward_fn(*input_tensors)
@ -848,7 +848,7 @@ def _merge_criteria_processor_list(default_list, custom_list):
for default in default_list:
for custom in custom_list:
if type(custom) is type(default):
raise ValueError()
raise ValueError
default_list.extend(custom_list)
return default_list
@ -2573,7 +2573,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
if self.i < 3:
self.i += 1
return self.i
raise StopIteration()
raise StopIteration
@torch.compile(backend="eager", fullgraph=True)
def fn(x):

View File

@ -147,10 +147,10 @@ class SkipNonTensorTests(torch._dynamo.test_case.TestCase):
class Foo(list):
def __iter__(self):
raise Exception()
raise Exception
def __len__(self):
raise Exception()
raise Exception
x = Foo()
x.append(torch.randn(4))

View File

@ -233,7 +233,7 @@ class StructuredTraceTest(TestCase):
import torch._inductor.lowering
def throw(x):
raise AssertionError()
raise AssertionError
# inject an error in the lowerings
dict_entries = {}

View File

@ -732,12 +732,12 @@ class Operator:
def any_opinfo_attr(self, attr):
if not self.has_opinfo():
raise RuntimeError()
raise RuntimeError
return any(getattr(opinfo, attr) for opinfo in self.opinfos)
def all_opinfo_attr(self, attr):
if not self.has_opinfo():
raise RuntimeError()
raise RuntimeError
return all(getattr(opinfo, attr) for opinfo in self.opinfos)
def supports_vjp(self):
@ -870,7 +870,7 @@ class OperatorSet:
elif n.startswith(torch_dot):
names_sanitized.append(n[len(torch_dot) :])
else:
raise AssertionError()
raise AssertionError
return cls.from_names(names_sanitized)
def query(self, operator_method, filter=(Support.NO, Support.YES, Support.UNKNOWN)):

View File

@ -403,7 +403,7 @@ class TestMin(TestCase):
# test with too many elements
try:
A[1, ..., 1, 1]
raise NotImplementedError()
raise NotImplementedError
except IndexError:
pass
c, d = dims()
@ -415,7 +415,7 @@ class TestMin(TestCase):
)
try:
A[..., 3, ...]
raise NotImplementedError()
raise NotImplementedError
except DimensionBindError:
pass

View File

@ -285,7 +285,7 @@ class TestInductorDynamic(TestCase):
@custom_ops.custom_op("test::foo")
def foo(x: torch.Tensor, y: int) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl("test::foo")
def foo_impl(x: torch.Tensor, y: int) -> torch.Tensor:
@ -401,7 +401,7 @@ class TestInductorDynamic(TestCase):
@custom_ops.custom_op("test::foo")
def foo(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl("test::foo")
def foo_impl(x: torch.Tensor) -> torch.Tensor:

View File

@ -2632,7 +2632,7 @@ class TestFrozenOptimizations(JitTestCase):
res3 = torch._C._nn.linear(in_tensor2, self.w2, self.b2)
res4 = torch._C._nn.linear(in_tensor1, self.w2, self.b1)
else:
raise AssertionError()
raise AssertionError
res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b1)
return res1, res2, res3, res4

View File

@ -2871,7 +2871,7 @@ class TestScriptList(JitTestCase):
def __next__(self):
if self.value == limit: # noqa: F821
raise StopIteration()
raise StopIteration
ret = self.value
self.value += 1

View File

@ -54,7 +54,7 @@ class ClosuresTest(TestCase):
torch._lazy.add_step_closure(closure)
torch._lazy.mark_step()
raise AssertionError() # Should not reach here
raise AssertionError # Should not reach here
except RuntimeError as e:
assert flag.is_set(), "Should have caught exception from closure"
@ -79,7 +79,7 @@ class ClosuresTest(TestCase):
torch._lazy.add_step_closure(closure2, run_async=True)
torch._lazy.mark_step()
raise AssertionError() # Should not reach here
raise AssertionError # Should not reach here
except RuntimeError as e:
# Should have caught exception from closure1
pass

View File

@ -283,7 +283,7 @@ class TestOperators(common_utils.TestCase):
def symbolic(g, x):
# The inside of this function should never be invoked, because
# we will fail due to an argument mismatch first.
raise AssertionError()
raise AssertionError
@staticmethod
def forward(ctx, x, y):

View File

@ -154,7 +154,7 @@ class Errors:
NB: It is an error to "fail" without having added any errors to
the error context.
"""
raise self.exc_class()
raise self.exc_class
def failWith(self, msg):
"""
@ -489,7 +489,7 @@ def verify(
errs.requireEqual(
proto_bytes.getvalue(), alt_proto_bytes.getvalue()
)
raise AssertionError()
raise AssertionError
# TODO: test that the traced model also returns the same thing...
run_helper(torch_out, args, remained_onnx_input_idx)

View File

@ -98,7 +98,7 @@ class TestImporter(PackageTestCase):
self._whichmodule_return = whichmodule_return
def import_module(self, module_name):
raise NotImplementedError()
raise NotImplementedError
def whichmodule(self, obj, name):
return self._whichmodule_return

View File

@ -1970,7 +1970,7 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
try:
with cm:
x.add(y)
raise ValueError()
raise ValueError
x.relu()
except ValueError:
pass

View File

@ -509,7 +509,7 @@ class TestCustomOp(CustomOpTestCaseBase):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def baz(x: Tensor) -> Tensor:
raise NotImplementedError()
raise NotImplementedError
def test_unsupported_schemas(self):
with self.assertRaisesRegex(ValueError, "only supports functional"):
@ -670,35 +670,35 @@ class TestCustomOp(CustomOpTestCaseBase):
with self.assertRaisesRegex(ValueError, "varargs"):
def foo(*args):
raise NotImplementedError()
raise NotImplementedError
infer_schema(foo)
with self.assertRaisesRegex(ValueError, "varkwargs"):
def foo(**kwargs):
raise NotImplementedError()
raise NotImplementedError
infer_schema(foo)
with self.assertRaisesRegex(ValueError, "must have a type annotation"):
def foo(x):
raise NotImplementedError()
raise NotImplementedError
infer_schema(foo)
with self.assertRaisesRegex(ValueError, "unsupported"):
def foo(x: Tensor) -> Tuple[Tensor, ...]:
raise NotImplementedError()
raise NotImplementedError
infer_schema(foo)
with self.assertRaisesRegex(ValueError, "can be mutated"):
def foo(x: Tensor, y: int) -> Tensor:
raise NotImplementedError()
raise NotImplementedError
infer_schema(foo, mutates_args={"y"})
@ -752,7 +752,7 @@ class TestCustomOp(CustomOpTestCaseBase):
@custom_ops.custom_op(f"{self.test_ns}::foo")
def foo(x: Tensor) -> typ:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{self.test_ns}::foo")
def foo_impl(x: Tensor) -> typ:
@ -771,7 +771,7 @@ class TestCustomOp(CustomOpTestCaseBase):
@custom_ops.custom_op(f"{self.test_ns}::foo")
def foo(x: Tensor) -> Tuple[typ, typ]:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{self.test_ns}::foo")
def foo_impl(x: Tensor) -> Tuple[typ, typ]:
@ -789,7 +789,7 @@ class TestCustomOp(CustomOpTestCaseBase):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: Tensor, y: typ) -> Tensor:
raise NotImplementedError()
raise NotImplementedError
yeet = None
@ -823,7 +823,7 @@ class TestCustomOp(CustomOpTestCaseBase):
@custom_ops.custom_op(f"{self.test_ns}::foo")
def foo(x: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
called = 0
@ -847,7 +847,7 @@ class TestCustomOp(CustomOpTestCaseBase):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: Tensor, y: List[Optional[int]]) -> Tensor:
raise NotImplementedError()
raise NotImplementedError
del foo
@ -855,7 +855,7 @@ class TestCustomOp(CustomOpTestCaseBase):
# int[N] in Dispatcher is a bit wild, so we don't try to support it.
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: Tensor, y: Tuple[int, int]) -> Tensor:
raise NotImplementedError()
raise NotImplementedError
del foo
@ -863,7 +863,7 @@ class TestCustomOp(CustomOpTestCaseBase):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: Tensor, y: Callable) -> Tensor:
raise NotImplementedError()
raise NotImplementedError
del foo
@ -910,7 +910,7 @@ class TestCustomOp(CustomOpTestCaseBase):
@custom_ops.custom_op(f"{ns}::foo2")
def foo2(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
def test_private_ctor(self):
with self.assertRaisesRegex(RuntimeError, "CustomOp constructor is private"):
@ -919,7 +919,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_lifetime(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
custom_op = torch._custom_op.impl.get_op(f"{TestCustomOp.test_ns}::foo")
@ -928,7 +928,7 @@ class TestCustomOp(CustomOpTestCaseBase):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811
raise NotImplementedError()
raise NotImplementedError
# Unless we delete the original op.
custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
@ -936,14 +936,14 @@ class TestCustomOp(CustomOpTestCaseBase):
# Smoke test
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811
raise NotImplementedError()
raise NotImplementedError
custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
def test_autograd_notimplemented(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811
raise NotImplementedError()
raise NotImplementedError
x = torch.randn(3, requires_grad=True)
op = self.get_op(f"{self.test_ns}::foo")
@ -954,7 +954,7 @@ class TestCustomOp(CustomOpTestCaseBase):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: Sequence[torch.Tensor]) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
x = torch.randn(3, requires_grad=True)
y = torch.randn(3)
@ -966,7 +966,7 @@ class TestCustomOp(CustomOpTestCaseBase):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
x = torch.randn(3, requires_grad=True)
y = torch.randn(3)
@ -978,7 +978,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_autograd_notimplemented_gradmode(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
def foo_impl(x, y):
@ -994,7 +994,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_impl_cpu(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu")
def foo_cpu(x):
@ -1008,7 +1008,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_impl_invalid_devices(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
def foo_impl(x):
return x.sin()
@ -1033,7 +1033,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_backward_partially_registered(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
def foo_impl(x):
@ -1054,7 +1054,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_save_for_backward_inputs_are_namedtuple(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
def foo_impl(x):
@ -1084,7 +1084,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_backward_returns_dict(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
def foo_impl(x):
@ -1107,7 +1107,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_backward_dict_invalid_keys(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
def foo_impl(x):
@ -1130,7 +1130,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_backward_dict_grad_for_nontensor(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
def foo_impl(x, dim):
@ -1153,7 +1153,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_backward_dict_requires_keys_for_input_tensors(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
def foo_impl(x, y):
@ -1176,7 +1176,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_backward_dict_requires_keys_for_input_optional_tensors(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
def foo_impl(x, y):
@ -1199,7 +1199,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_backward_grads_are_tensor_or_none(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
def foo_impl(x):
@ -1222,7 +1222,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_backward_tensorlist_input_requires_list_grads_with_same_numel(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
def foo_impl(xs):
@ -1245,7 +1245,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_backward_tensorlist_input_requires_list_grads_none_or_Tensor(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
def foo_impl(xs):
@ -1268,7 +1268,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_backward_tensorlist_input_requires_list_grads(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
def foo_impl(xs):
@ -1291,7 +1291,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_backward_output_differentiability_type(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
with self.assertRaisesRegex(RuntimeError, "output_differentiability"):
@ -1304,7 +1304,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_backward_output_differentiability_numel(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(xs: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError()
raise NotImplementedError
with self.assertRaisesRegex(RuntimeError, "output_differentiability"):
@ -1317,7 +1317,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_backward_output_differentiability_tensorlist(self):
@custom_ops.custom_op(f"{self.test_ns}::foo")
def foo(x: Tensor) -> Tuple[List[Tensor], Tensor]:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{self.test_ns}::foo")
def foo_impl(x):
@ -1343,7 +1343,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_backward_output_differentiability_non_tensor(self):
@custom_ops.custom_op(f"{self.test_ns}::foo")
def foo(x: Tensor) -> Tuple[Tensor, int]:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{self.test_ns}::foo")
def foo_impl(x):
@ -1368,7 +1368,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_impl_separate(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu")
def foo_cpu(x):
@ -1392,7 +1392,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_impl_multiple(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
def foo_impl(x):
@ -1422,7 +1422,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_impl_meta(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
def foo_meta(x, dim):
@ -1438,7 +1438,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_duplicate_impl(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
def foo_meta(x, dim):
@ -1457,7 +1457,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_new_data_dependent_symint(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
def foo_meta(x):
@ -1483,7 +1483,7 @@ class TestCustomOp(CustomOpTestCaseBase):
# this one is just a sanity check.
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
@torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
def foo_meta(x):
@ -1497,7 +1497,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def test_not_implemented_error(self):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
def foo(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
x = torch.randn(3)
op = self.get_op(f"{self.test_ns}::foo")
@ -1510,7 +1510,7 @@ class TestCustomOp(CustomOpTestCaseBase):
@custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar")
def bar(sizes: Sequence[int]) -> torch.Tensor:
raise NotImplementedError()
raise NotImplementedError
op = self.get_op(f"{self.test_ns}::bar")
with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"):
@ -2021,7 +2021,7 @@ class MiniOpTest(CustomOpTestCaseBase):
@staticmethod
def backward(ctx, grad):
raise NotImplementedError()
raise NotImplementedError
def autograd_impl(x):
return Op.apply(x)

View File

@ -492,7 +492,7 @@ class TestBinary(TestCase):
mt1 = masked_tensor(data1, mask1)
try:
fn(mt0, mt1)
raise AssertionError()
raise AssertionError
except ValueError as e:
assert (
"Input masks must match. If you need support for this, please open an issue on Github."

View File

@ -6603,7 +6603,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
elif weight_layout == torch.sparse_coo:
module.weight = nn.Parameter(module.weight.to_sparse_coo())
else:
raise AssertionError()
raise AssertionError
inp = torch.randn(4, requires_grad=True, device=device)
res = module(inp)

View File

@ -1206,7 +1206,7 @@ class TestTorchFunctionMode(TestCase):
class A(TorchFunctionMode):
def __torch_function__(self, *args, **kwargs):
raise ErrorA()
raise ErrorA
with self.assertRaises(ErrorA):
with A():
@ -1218,7 +1218,7 @@ class TestTorchFunctionMode(TestCase):
class A(TorchFunctionMode):
def __torch_function__(self, *args, **kwargs):
raise ErrorA()
raise ErrorA
x = A()
with self.assertRaises(ErrorA):

View File

@ -1238,7 +1238,7 @@ $3: f32[] = torch._ops.aten.add.Tensor($1, $2)""")
class AMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if func.__name__ == 'randn.default':
raise RuntimeError()
raise RuntimeError
return A(torch.zeros(()))
with AMode():
@ -1254,7 +1254,7 @@ $3: f32[] = torch._ops.aten.add.Tensor($1, $2)""")
class A(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise ErrorA()
raise ErrorA
x = A()
with self.assertRaises(ErrorA):

View File

@ -1081,7 +1081,7 @@ class TestTraceback(TestCase):
source = '''\
def f(x):
def g(x):
raise RuntimeError() # HEYA
raise RuntimeError # HEYA
x = x * 3
return g(x) + 1
@ -1099,7 +1099,7 @@ def f(x):
def test_format_traceback_short(self):
try:
raise RuntimeError()
raise RuntimeError
except RuntimeError as e:
self.assertRegex(format_traceback_short(e.__traceback__), r'.*test_utils.py:\d+ in test_format_traceback_short')

View File

@ -525,7 +525,7 @@ class WeakKeyDictionaryTestCase(TestCase):
return self
def __next__(self):
raise Exc()
raise Exc
self.assertRaises(Exc, d.update, badseq())
@ -866,7 +866,7 @@ class WeakKeyDictionaryScriptObjectTestCase(TestCase):
return self
def __next__(self):
raise Exc()
raise Exc
self.assertRaises(Exc, d.update, badseq())

View File

@ -1136,14 +1136,14 @@ class TestCreation(TestCase):
return 1
def __getitem__(self, index):
raise ValueError()
raise ValueError
class Map:
def __len__(self):
return 1
def __getitem__(self, index):
raise KeyError()
raise KeyError
a = np.array([Map()])
assert_(a.shape == (1,))
@ -1160,7 +1160,7 @@ class TestCreation(TestCase):
if ind in [0, 1]:
return ind
else:
raise IndexError()
raise IndexError
d = np.array([Point2(), Point2(), Point2()])
assert_equal(d.dtype, np.dtype(object))

View File

@ -342,7 +342,7 @@ class TestFFT1D(TestCase):
Y_res = fft(Y, axes=ax)
assert_allclose(X_res, Y_res, atol=_tol, rtol=_tol)
else:
raise ValueError()
raise ValueError
@skipif(IS_WASM, reason="Cannot start thread")

View File

@ -1435,7 +1435,7 @@ class TestVectorize(TestCase):
try:
vectorize(random.randrange) # Should succeed
except Exception:
raise AssertionError() # noqa: TRY200
raise AssertionError # noqa: TRY200
def test_keywords2_ticket_2100(self):
# Test kwarg support: enhancement ticket 2100

View File

@ -1252,7 +1252,7 @@ def emit_body(
if a.name == derivative_var_name:
break
else:
raise AssertionError()
raise AssertionError
return f"grad_fn->should_compute_output({edge_off})"
if is_inplace_foreach:

View File

@ -61,7 +61,7 @@ def custom_op(qualname, func_or_schema=None):
>>> # we will infer the types of the inputs and outputs.
>>> @torch._custom_ops.custom_op("mylibrary::numpy_sin")
>>> def numpy_sin(x: Tensor) -> Tensor:
>>> raise NotImplementedError()
>>> raise NotImplementedError
>>>
>>> # The custom op is now accessible via the torch.ops module:
>>> torch.ops.mylibrary.numpy_sin
@ -143,7 +143,7 @@ def impl(qualname, *, device_types=("cpu", "cuda"), func=None):
>>> # we will infer the types of the inputs and outputs.
>>> @torch._custom_ops.custom_op("mylibrary::numpy_cos")
>>> def numpy_cos(x: Tensor) -> Tensor:
>>> raise NotImplementedError()
>>> raise NotImplementedError
>>>
>>> # The custom op is now accessible via the torch.ops module:
>>> torch.ops.mylibrary.numpy_cos
@ -207,7 +207,7 @@ def impl_abstract(qualname, *, func=None):
>>> # Example 1: an operator without data-dependent output shape
>>> @torch._custom_ops.custom_op("mylibrary::custom_linear")
>>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
>>> raise NotImplementedError()
>>> raise NotImplementedError
>>>
>>> @torch._custom_ops.impl_abstract("mylibrary::custom_linear")
>>> def custom_linear_abstract(x, weight):

View File

@ -129,7 +129,7 @@ class TestingOnlyCompileError(Exception):
def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
for node in gm.graph.nodes:
if node.target == torch.relu:
raise ReluCompileError()
raise ReluCompileError
return gm
@ -165,7 +165,7 @@ def non_leaf_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs
return gm
for t in example_inputs:
if not t.is_leaf:
raise TestingOnlyCompileError()
raise TestingOnlyCompileError
return gm

View File

@ -39,7 +39,7 @@ class DeviceInterface(metaclass=DeviceInterfaceMeta):
class device:
def __new__(cls, device: _device_t):
raise NotImplementedError()
raise NotImplementedError
class Worker:
"""
@ -51,71 +51,71 @@ class DeviceInterface(metaclass=DeviceInterfaceMeta):
@staticmethod
def set_device(device: int):
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def current_device() -> int:
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def get_device_properties(device: _device_t = None):
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def current_device():
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def set_device(device: _device_t):
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def maybe_exchange_device(device: int) -> int:
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def exchange_device(device: int) -> int:
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def device_count():
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def is_available() -> bool:
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def stream(stream: torch.Stream):
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def current_stream():
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def set_stream(stream: torch.Stream):
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def _set_stream_by_id(stream_id: int, device_index: int, device_type: int):
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def get_raw_stream():
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def synchronize(device: _device_t = None):
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def get_device_properties(device: _device_t = None):
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def get_compute_capability(device: _device_t = None):
raise NotImplementedError()
raise NotImplementedError
class DeviceGuard:

View File

@ -215,7 +215,7 @@ class EphemeralSource(Source):
return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>"
def make_guard(self):
raise NotImplementedError()
raise NotImplementedError
def is_ephemeral(self):
return True
@ -277,7 +277,7 @@ class NegateSource(ChainedSource):
assert self.base is not None
def reconstruct(self, codegen):
raise NotImplementedError()
raise NotImplementedError
def guard_source(self):
return self.base.guard_source()
@ -516,7 +516,7 @@ class ConstantSource(Source):
return self.source_name
def make_guard(self, fn):
raise NotImplementedError()
raise NotImplementedError
@dataclasses.dataclass(frozen=True)

View File

@ -1242,7 +1242,7 @@ class InstructionTranslatorBase(
if (
isinstance(val, BuiltinVariable) and val.fn is StopIteration
) or isinstance(val, variables.StopIterationVariable):
raise exc.UserStopIteration()
raise exc.UserStopIteration
unimplemented(f"raise {exc}")
else:
unimplemented("raise ... from ...")
@ -2231,7 +2231,7 @@ class InstructionTranslator(InstructionTranslatorBase):
return self.f_locals[source.local_name]
if isinstance(source, GlobalSource):
return self.f_globals[source.global_name]
raise KeyError()
raise KeyError
def run(self):
super().run()
@ -2388,7 +2388,7 @@ class InstructionTranslator(InstructionTranslatorBase):
else create_instruction("RETURN_CONST", argval=inst.argval)
)
self.output.add_output_instructions([return_inst])
raise ReturnValueOp()
raise ReturnValueOp
def RETURN_VALUE(self, inst):
self._return(inst)
@ -2637,7 +2637,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
self.output.root_tx.mutated_closure_cell_contents.add(
maybe_cell.source.name()
)
raise exc.UnspecializeRestartAnalysis()
raise exc.UnspecializeRestartAnalysis
unimplemented("write to __closure__ while inlining")
def LOAD_DEREF(self, inst):
@ -2676,12 +2676,12 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
def RETURN_VALUE(self, inst):
self.symbolic_result = self.pop() # type: ignore[assignment]
self.instruction_pointer = None
raise ReturnValueOp()
raise ReturnValueOp
def RETURN_CONST(self, inst):
self.symbolic_result = self._load_const(inst)
self.instruction_pointer = None
raise ReturnValueOp()
raise ReturnValueOp
class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):

View File

@ -219,17 +219,17 @@ class VariableTracker(metaclass=VariableTrackerMeta):
def make_guard(self, fn):
if self.source:
return self.source.make_guard(fn)
raise NotImplementedError()
raise NotImplementedError
def const_getattr(self, tx, name: str) -> Any:
"""getattr(self, name) returning a python constant"""
raise NotImplementedError()
raise NotImplementedError
def var_getattr(self, tx, name: str) -> "VariableTracker":
"""getattr(self, name) returning a new variable"""
value = self.const_getattr(tx, name)
if not variables.ConstantVariable.is_literal(value):
raise NotImplementedError()
raise NotImplementedError
source = None
if self.source:
source = AttrSource(self.source, name)
@ -257,7 +257,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
return None
def reconstruct(self, codegen):
raise NotImplementedError()
raise NotImplementedError
def can_reconstruct(self, tx):
"""If it is possible to reconstruct the Python object this
@ -273,7 +273,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
return False
def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
raise NotImplementedError()
raise NotImplementedError
def has_unpack_var_sequence(self, tx) -> bool:
try:

View File

@ -122,7 +122,7 @@ class ConstantVariable(VariableTracker):
)
member = getattr(self.value, name)
if callable(member):
raise NotImplementedError()
raise NotImplementedError
return member
def call_method(
@ -212,5 +212,5 @@ class EnumVariable(VariableTracker):
def const_getattr(self, tx, name):
member = getattr(self.value, name)
if callable(member):
raise NotImplementedError()
raise NotImplementedError
return member

View File

@ -420,7 +420,7 @@ class DictView(VariableTracker):
def view_items_vt(self):
# Returns an iterable of the unpacked items
# Implement in the subclasses
raise NotImplementedError()
raise NotImplementedError
def unpack_var_sequence(self, tx):
def unwrap(x):
@ -615,7 +615,7 @@ class DataClassVariable(ConstDictVariable):
assert self.is_matching_cls(user_cls)
def as_proxy(self):
raise NotImplementedError()
raise NotImplementedError
def reconstruct(self, codegen):
codegen.extend_output([codegen._create_load_const(self.user_cls)])
@ -737,14 +737,14 @@ class CustomizedDictVariable(ConstDictVariable):
# called from builder.py
@classmethod
def wrap(cls, builder, obj):
raise NotImplementedError()
raise NotImplementedError
def __init__(self, items, user_cls, **options):
super().__init__(items, user_cls, **options)
assert self.is_matching_cls(user_cls)
def as_proxy(self):
raise NotImplementedError()
raise NotImplementedError
# 'RETURN_VALUE triggered compile'
# called from torch/_dynamo/codegen.py

View File

@ -439,7 +439,7 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
def get_function(self):
if self.closure:
raise NotImplementedError()
raise NotImplementedError
func = types.FunctionType(
self.code.as_python_constant(),
self.f_globals,

View File

@ -160,7 +160,7 @@ class RangeVariable(BaseListVariable):
elif len(items_to_map) == 3:
start, stop, step = items_to_map
else:
raise AssertionError()
raise AssertionError
assert stop is not None
super().__init__([start, stop, step], **kwargs)
@ -592,7 +592,7 @@ class SliceVariable(BaseListVariable):
elif len(items_to_map) == 3:
start, stop, step = items_to_map
else:
raise AssertionError()
raise AssertionError
if isinstance(start, variables.TensorVariable) or isinstance(
stop, variables.TensorVariable
@ -644,7 +644,7 @@ class ListIteratorVariable(VariableTracker):
assert self.mutable_local
old_index = self.index
if old_index >= len(self.items):
raise StopIteration()
raise StopIteration
tx.output.side_effects.mutation(self)
self.index += 1
return self.items[old_index]
@ -665,7 +665,7 @@ class ListIteratorVariable(VariableTracker):
def as_python_constant(self):
if self.index > 0:
raise NotImplementedError()
raise NotImplementedError
return iter([x.as_python_constant() for x in self.items])
def unpack_var_sequence(self, tx):
@ -748,7 +748,7 @@ class RestrictedListSubclassVariable(ListVariable):
return [x.as_proxy() for x in self.items]
def as_python_constant(self):
raise NotImplementedError()
raise NotImplementedError
def is_python_constant(self):
return False

View File

@ -592,13 +592,13 @@ class GetAttrVariable(VariableTracker):
def const_getattr(self, tx, name):
if not isinstance(self.obj, variables.NNModuleVariable):
raise NotImplementedError()
raise NotImplementedError
step1 = tx.output.get_submodule(self.obj.module_key)
if self.name not in step1.__dict__:
raise NotImplementedError()
raise NotImplementedError
step2 = inspect.getattr_static(step1, self.name)
if name not in step2.__dict__:
raise NotImplementedError()
raise NotImplementedError
return inspect.getattr_static(step2, name)
def reconstruct(self, codegen):

View File

@ -153,7 +153,7 @@ class NNModuleVariable(VariableTracker):
# Mark the class dynamic unless its module initialization
if tx.f_code.co_name != "__init__":
GenerationTracker.mark_class_dynamic(type(mod))
raise UnspecializeRestartAnalysis()
raise UnspecializeRestartAnalysis
def _custom_getattr_fallback(self, base, tx, name, options):
"""Check for a __getattr__ and handle it specially if it is implemented"""

View File

@ -157,7 +157,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
):
return self.value.param_groups[arg.source.index]
raise ArgMappingException()
raise ArgMappingException
new_args = [map_arg(arg) for arg in args]
new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}

View File

@ -222,7 +222,7 @@ class TensorVariable(VariableTracker):
return SourcelessBuilder.create(tx, example_value)
if not (self.source and self.source.subguards_allowed()):
raise NotImplementedError()
raise NotImplementedError
# For local source, we associate the real value. We use this real value
# for implementing getattr fallthrough on the variable tracker base class.
@ -238,23 +238,23 @@ class TensorVariable(VariableTracker):
# Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope.
_input_associated_real_value = eval(self.source.name(), scope)
except Exception as exc:
raise NotImplementedError() from exc
raise NotImplementedError from exc
if _input_associated_real_value is None:
raise NotImplementedError()
raise NotImplementedError
if object_has_getattribute(_input_associated_real_value):
raise NotImplementedError()
raise NotImplementedError
if get_custom_getattr(_input_associated_real_value):
raise NotImplementedError()
raise NotImplementedError
real_value = getattr(_input_associated_real_value, name)
if callable(real_value):
# Callables have more nuanced handling, and we should let the existing system delegate here.
# Raising was past behavior and so should always be sound to fall back.
# Note - at a certain point we may want to handle
raise NotImplementedError()
raise NotImplementedError
from ..guards import GuardBuilder
from .builder import VariableBuilder
@ -391,7 +391,7 @@ class TensorVariable(VariableTracker):
result = self.dynamic_getattr(tx, name)
if result is None:
raise NotImplementedError()
raise NotImplementedError
return result
def has_unpack_var_sequence(self, tx):
@ -1090,7 +1090,7 @@ class NumpyNdarrayVariable(TensorVariable):
elif name in ["__version__"]:
unimplemented("delegate np.__version__ to NumPy")
if result is None:
raise NotImplementedError()
raise NotImplementedError
return result
@staticmethod

View File

@ -72,7 +72,7 @@ class FuncTorchInterpreter(ABC):
return self._cptr.key()
def get_state(self):
raise NotImplementedError()
raise NotImplementedError
def check_state(self, state):
return state == self.get_state()

View File

@ -797,17 +797,17 @@ class Source:
return False
def reconstruct(self, codegen):
raise NotImplementedError()
raise NotImplementedError
def guard_source(self) -> GuardSource:
raise NotImplementedError()
raise NotImplementedError
def name(self) -> str:
raise NotImplementedError()
raise NotImplementedError
def make_guard(self, fn) -> Guard:
if self.guard_source() is GuardSource.CONSTANT:
raise NotImplementedError()
raise NotImplementedError
return Guard(self, fn)
def is_nn_module(self) -> bool:

View File

@ -422,7 +422,7 @@ class BenchmarkRequest:
def make_run_fn(
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
) -> Callable[[], None]:
raise NotImplementedError()
raise NotImplementedError
def cleanup_run_fn(self) -> None:
pass

View File

@ -446,7 +446,7 @@ def _reduce_tensor(t):
# TODO: These tensors don't currently pickle, so we can't cache a
# compiled graph containing them. Just fail now. If mkldnn tensors
# get pickling support, we can remove this.
raise BypassFxGraphCache()
raise BypassFxGraphCache
# Very large tensors could be expensive to copy to cpu and hash. Let's
# at least report if we find slowness.
@ -598,7 +598,7 @@ class FxGraphHashDetails:
# Some configs options are callables, e.g., post_grad_custom_pre_pass,
# and may not pickle.
log.debug("Can't pickle inductor config: %s", e)
raise BypassFxGraphCache() from e
raise BypassFxGraphCache from e
def debug_str(self) -> str:
"""
@ -843,19 +843,19 @@ class FxGraphCache:
"""
# Freezing can embed constants that wouldn't be static across runs.
if config.freezing or config.aot_inductor.use_runtime_constant_folding:
raise BypassFxGraphCache()
raise BypassFxGraphCache
# The treatment of guards in the caching implementation requires that
# we have a shape env.
if FxGraphCache._get_shape_env() is None:
log.debug("fx graph cache no shape env")
raise BypassFxGraphCache()
raise BypassFxGraphCache
# HigherOrderOperators should be handled on a case-by-case basis.
# Currently, we just skip caching if we have any.
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.HigherOrderOperator):
raise BypassFxGraphCache()
raise BypassFxGraphCache
@staticmethod
def load(
@ -990,7 +990,7 @@ def cpp_compiler_search(search: str) -> str:
return cxx
except (subprocess.SubprocessError, FileNotFoundError, ImportError):
continue
raise exc.InvalidCxxCompiler()
raise exc.InvalidCxxCompiler
def install_gcc_via_conda() -> str:
@ -2745,7 +2745,7 @@ def _worker_compile_triton(
class CodeCacheFuture:
def result(self):
raise NotImplementedError()
raise NotImplementedError
class TritonFuture(CodeCacheFuture):

View File

@ -88,16 +88,16 @@ device_codegens: Dict[str, DeviceCodegen] = {}
class DeviceOpOverrides:
def import_get_raw_stream_as(self, name):
raise NotImplementedError()
raise NotImplementedError
def set_device(self, device_idx):
raise NotImplementedError()
raise NotImplementedError
def synchronize(self):
raise NotImplementedError()
raise NotImplementedError
def device_guard(self, device_idx):
raise NotImplementedError()
raise NotImplementedError
device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}
@ -1368,7 +1368,7 @@ class Kernel(CodeGen):
self.cse = cse
def load(self, name: str, index: sympy.Expr) -> CSEVariable:
raise NotImplementedError()
raise NotImplementedError
def indirect_load(self, name: str, index: sympy.Expr):
"""A load the depends on an index we have read"""
@ -1381,12 +1381,12 @@ class Kernel(CodeGen):
self.loads = prior
def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
raise NotImplementedError()
raise NotImplementedError
def store(
self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
) -> None:
raise NotImplementedError()
raise NotImplementedError
def reduction(
self,
@ -1395,7 +1395,7 @@ class Kernel(CodeGen):
reduction_type: ReductionType,
value: Union[CSEVariable, Tuple[CSEVariable, ...]],
) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
raise NotImplementedError()
raise NotImplementedError
def scan(
self,
@ -1405,7 +1405,7 @@ class Kernel(CodeGen):
],
values: Tuple[CSEVariable, ...],
) -> Tuple[CSEVariable, ...]:
raise NotImplementedError()
raise NotImplementedError
def bucketize(
self,
@ -1418,11 +1418,11 @@ class Kernel(CodeGen):
"""
See [Note: Inductor bucketize op]
"""
raise NotImplementedError()
raise NotImplementedError
@property
def assert_function(self) -> str:
raise NotImplementedError()
raise NotImplementedError
def indirect_assert(self, var, lower, upper, mask=None):
if lower and upper:
@ -1444,7 +1444,7 @@ class Kernel(CodeGen):
return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'
def index_to_str(self, index: sympy.Expr) -> str:
raise NotImplementedError()
raise NotImplementedError
def __enter__(self):
# TODO: hoist this to top level
@ -1737,4 +1737,4 @@ class KernelTemplate:
Generates a ChoiceCaller instance from the given arguments.
"""
raise NotImplementedError()
raise NotImplementedError

View File

@ -2659,7 +2659,7 @@ class CppVecKernel(CppKernel):
mean, m2, weight = reduction_project(reduction_type, next_value)
return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})"
else:
raise NotImplementedError()
raise NotImplementedError
def indirect_assert(self, var, lower, upper, mask=None):
assert not mask, "do not support mask in indirect_indexing assertion"

View File

@ -197,7 +197,7 @@ class CutlassEVTEpilogueTypeFormatter:
return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::maximum, ElementAcc, ElementAcc, RoundStyle>,{a}, {const_zero}>" # noqa: B950
def reduction(self, dtype, src_dtype, reduction_type, value):
raise CUTLASSEVTOpNotImplementedError()
raise CUTLASSEVTOpNotImplementedError
# Add more ops here...
def getvalue(self, result) -> str:
@ -354,7 +354,7 @@ class CutlassEVTEpilogueArgumentFormatter:
return a
def reduction(self, dtype, src_dtype, reduction_type, value):
raise CUTLASSEVTOpNotImplementedError()
raise CUTLASSEVTOpNotImplementedError
def getvalue(self, result) -> str:
return "{" + str(result) + "}"

View File

@ -134,15 +134,15 @@ class AllocationTreeNode:
def get_live_ranges(self) -> LiveRanges:
"""Aggregate LiveRanges for all objects below this in tree"""
raise NotImplementedError()
raise NotImplementedError
def get_size_hint(self) -> int:
"""Number of bytes used for example inputs"""
raise NotImplementedError()
raise NotImplementedError
def get_symbolic_size(self) -> sympy.Expr:
"""Number of bytes needed at runtime"""
raise NotImplementedError()
raise NotImplementedError
def finalize(self, pool, offset) -> AllocationTreeNode:
"""Called after all allocations have been made"""

View File

@ -1468,7 +1468,7 @@ class TritonKernel(Kernel):
def add_range(i, expr):
expr = sv.simplify(expr)
if not sv.statically_known_multiple_of(remaining[i], expr):
raise CantSplit()
raise CantSplit
# guard on the last item out
remaining[i] = FloorDiv(remaining[i], expr)
new_ranges[i].append(expr)
@ -1501,7 +1501,7 @@ class TritonKernel(Kernel):
if not sv.statically_known_multiple_of(
size, remaining[current_group]
):
raise CantSplit()
raise CantSplit
size1 = remaining[current_group]
size2 = FloorDiv(size, remaining[current_group])
return_getters.append(

View File

@ -1346,7 +1346,7 @@ class WrapperCodeGen(CodeGen):
self.lines.append(LineContext(ctx))
def val_to_cpp_arg_str(self, type_, val) -> str:
raise NotImplementedError()
raise NotImplementedError
def val_to_arg_str(self, s):
from torch.utils._triton import dtype_to_string, has_triton_package

View File

@ -954,10 +954,10 @@ class GraphLowering(torch.fx.Interpreter):
return self.add_tensor_constant(value, target)
def call_module(self, target, args, kwargs):
raise AssertionError()
raise AssertionError
def call_method(self, target, args, kwargs):
raise AssertionError()
raise AssertionError
def output(self, target, args, kwargs):
result = super().output(target, args, kwargs)

View File

@ -2252,7 +2252,7 @@ class View(GenericView):
size_old = size_old * modulus
V.graph.sizevars.guard_equals(size_new, size_old)
else:
raise AssertionError()
raise AssertionError
while stack_old:
size_old = stack_old.pop()
@ -2818,7 +2818,7 @@ class FlexibleLayout(Layout):
"stride_ordered_for_memory_format, unsuppored memory_format: %s",
memory_format,
)
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def same_ordered(sizes, stride):
@ -3666,16 +3666,16 @@ class ChoiceCaller:
return do_bench(lambda: algo(*args, out=out))
def call_name(self) -> str:
raise NotImplementedError()
raise NotImplementedError
def to_callable(self):
raise NotImplementedError()
raise NotImplementedError
def hash_key(self) -> str:
raise NotImplementedError()
raise NotImplementedError
def output_node(self) -> "TensorBox":
raise NotImplementedError()
raise NotImplementedError
def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
"""Information returned here is logged to the autotune log file when that is enabled."""
@ -3684,7 +3684,7 @@ class ChoiceCaller:
class TritonTemplateCallerBase(ChoiceCaller):
def get_make_kernel_render(self) -> Any:
raise NotImplementedError()
raise NotImplementedError
class MultiTemplateBuffer(TritonTemplateBuffer):
@ -4033,7 +4033,7 @@ class ExternKernel(InputsKernel):
wrapper.writeline(origin_str)
def codegen(self, wrapper):
raise NotImplementedError()
raise NotImplementedError
def get_kernel_name(self):
return self.cpp_kernel_name if V.graph.cpp_wrapper else self.python_kernel_name
@ -4157,7 +4157,7 @@ class ExternKernel(InputsKernel):
offset,
index,
)
raise NotImplementedError()
raise NotImplementedError
return ReinterpretView(
data=x.data,

View File

@ -1629,7 +1629,7 @@ def bernoulli_p(x, *args):
# This shouldn't be called in general
@register_lowering(aten._foobar)
def _foobar(_):
raise AssertionError()
raise AssertionError
@functools.lru_cache(1)

View File

@ -218,7 +218,7 @@ class PatternExpr:
def _match(
self, node: torch.fx.Node, ctx: MatchContext
) -> Union[Match, FailedMatch]:
raise NotImplementedError()
raise NotImplementedError
def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]:
try:
@ -361,7 +361,7 @@ class _TargetExpr(PatternExpr):
return isinstance(self.users, Multiple) or self.users > 1
def find_anchor_nodes(self, ctx: MatchContext, searched):
raise NotImplementedError()
raise NotImplementedError
def _match_fns(self, node: torch.fx.Node):
return (
@ -803,7 +803,7 @@ class PatternEntry:
extra_check: Callable[[Match], bool]
def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node):
raise NotImplementedError()
raise NotImplementedError
def register(self, pass_dicts, target=None, prepend=False):
if target is None:
@ -1507,7 +1507,7 @@ class PatternMatcherPass:
def _not_implemented(*args, **kwargs) -> NoReturn:
raise NotImplementedError()
raise NotImplementedError
def fx_to_pattern(

View File

@ -2514,13 +2514,13 @@ class BaseScheduling:
"""
Check whether node1 and node2 can be vertically fused or not.
"""
raise NotImplementedError()
raise NotImplementedError
def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
"""
Check whether node1 and node2 can be horizontally fused or not.
"""
raise NotImplementedError()
raise NotImplementedError
def fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
"""
@ -2535,7 +2535,7 @@ class BaseScheduling:
"""
Process the iteration sizes in case a transformation needs to be applied.
"""
raise NotImplementedError()
raise NotImplementedError
def codegen_template(
self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode]
@ -2546,19 +2546,19 @@ class BaseScheduling:
This function is only available for triton now. If the third-party backend behaves as a sub-class
of TritonScheduling, it can override it or reuse it.
"""
raise NotImplementedError()
raise NotImplementedError
def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]):
"""
Generate a kernel given a list of pre-fused nodes.
"""
raise NotImplementedError()
raise NotImplementedError
def codegen_sync(self):
"""
Generate synchronization code for the kernel. This method depends on the hardware characteristics.
"""
raise NotImplementedError()
raise NotImplementedError
def ready_to_flush(self) -> bool:
"""
@ -2571,14 +2571,14 @@ class BaseScheduling:
"""
Flush the generated kernel and python wrapper code to the source code file.
"""
raise NotImplementedError()
raise NotImplementedError
def benchmark_fused_nodes(self, nodes):
"""
Benchmark fused list of nodes and return the execution time
in milliseconds on randomly generated inputs.
"""
raise NotImplementedError()
raise NotImplementedError
def get_fusion_pair_priority(self, node1, node2) -> int:
"""

View File

@ -949,11 +949,11 @@ class DeferredLineBase:
def __call__(self) -> Optional[str]:
"""Returns either self.line or None to indicate the line has been 'unwritten'"""
raise NotImplementedError()
raise NotImplementedError
def _new_line(self, line: str) -> DeferredLineBase:
"""Returns a new deferred line with the same condition"""
raise NotImplementedError()
raise NotImplementedError
def with_prefix(self, prefix):
return self._new_line(f"{prefix}{self.line}")

View File

@ -217,7 +217,7 @@ def _split_helper_int(tensor, indices_or_sections, axis, strict=False):
l, n = tensor.shape[axis], indices_or_sections
if n <= 0:
raise ValueError()
raise ValueError
if l % n == 0:
num, sz = n, l // n

View File

@ -512,7 +512,7 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N
if like is not None:
raise NotImplementedError("'like' parameter is not supported.")
if order != "K":
raise NotImplementedError()
raise NotImplementedError
# a happy path
if (

View File

@ -174,7 +174,7 @@ def maybe_copy_to(out, result, promote_scalar_result=False):
maybe_copy_to(o, r, promote_scalar_result) for o, r in zip(out, result)
)
else:
raise AssertionError() # We should never hit this path
raise AssertionError # We should never hit this path
def wrap_tensors(result):

View File

@ -85,7 +85,7 @@ class OperatorBase:
self.functorch_table = {}
def __call__(self, *args, **kwargs):
raise NotImplementedError()
raise NotImplementedError
def has_kernel_for_dispatch_key(self, k):
return k in self.py_kernels
@ -165,7 +165,7 @@ class OperatorBase:
return fn
def name(self):
raise NotImplementedError()
raise NotImplementedError
is_included_in_alias = torch._C._dispatch_is_included_in_alias

View File

@ -6,27 +6,27 @@ class _StreamBase(ABC):
@abstractmethod
def wait_event(self, event):
raise NotImplementedError()
raise NotImplementedError
@abstractmethod
def wait_stream(self, stream):
raise NotImplementedError()
raise NotImplementedError
@abstractmethod
def record_event(self, event=None):
raise NotImplementedError()
raise NotImplementedError
@abstractmethod
def query(self):
raise NotImplementedError()
raise NotImplementedError
@abstractmethod
def synchronize(self):
raise NotImplementedError()
raise NotImplementedError
@abstractmethod
def __eq__(self, stream):
raise NotImplementedError()
raise NotImplementedError
class _EventBase(ABC):
@ -34,12 +34,12 @@ class _EventBase(ABC):
@abstractmethod
def wait(self, stream=None):
raise NotImplementedError()
raise NotImplementedError
@abstractmethod
def query(self):
raise NotImplementedError()
raise NotImplementedError
@abstractmethod
def synchronize(self):
raise NotImplementedError()
raise NotImplementedError

View File

@ -354,7 +354,7 @@ def trace(data):
elif e['action'] == 'oom':
size = e['size']
free = e['device_free']
out.write(f'raise OutOfMemoryError() # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n')
out.write(f'raise OutOfMemoryError # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n')
else:
out.write(f'{e}\n')
out.write(f"TOTAL MEM: {Bytes(count)}")

View File

@ -465,7 +465,7 @@ class FSDPParam:
return [_to_dtype_if_needed(sharded_param_data, self.param_dtype)]
elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
if hasattr(self._sharded_local_tensor, "fsdp_pre_all_gather"):
raise NotImplementedError()
raise NotImplementedError
all_gather_input = _to_dtype_if_needed(
cast(torch.Tensor, self._sharded_post_forward_param_data),
self.param_dtype,

View File

@ -39,7 +39,7 @@ class ParallelMode(ABC):
TODO(@wanchaol): some of these arguments are not necessary for
partitioning, remove the unnecessary ones later.
"""
raise NotImplementedError()
raise NotImplementedError
@abstractmethod
def transform_and_compile(self, gm: GraphModule) -> GraphModule:
@ -51,7 +51,7 @@ class ParallelMode(ABC):
the distributed environment.
"""
# TODO: add more necessary arguments to this interface.
raise NotImplementedError()
raise NotImplementedError
class DataParallel(ParallelMode):

View File

@ -146,7 +146,7 @@ def _iterate_state_dict(
not isinstance(companion_obj, (list, tuple))
or len(companion_obj) != len(iter_object)
):
raise CompanionMismatch()
raise CompanionMismatch
ret = [
_iterate_state_dict(
@ -437,7 +437,7 @@ def _check_state_dict_similarity(
companion_obj: Any,
) -> torch.Tensor:
if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size():
raise CompanionMismatch()
raise CompanionMismatch
return obj
try:

View File

@ -439,7 +439,7 @@ class ElasticAgent(abc.ABC):
Raises:
Exception - any other failures NOT related to worker process
"""
raise NotImplementedError()
raise NotImplementedError
@abc.abstractmethod
def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup:
@ -450,7 +450,7 @@ class ElasticAgent(abc.ABC):
Implementors are encouraged (but not required) to return
a defensive read-only copy.
"""
raise NotImplementedError()
raise NotImplementedError
class SimpleElasticAgent(ElasticAgent):
@ -477,7 +477,7 @@ class SimpleElasticAgent(ElasticAgent):
This is according to worker spec for the worker group .
Returns a map of ``local_rank`` to worker ``id``.
"""
raise NotImplementedError()
raise NotImplementedError
@abc.abstractmethod
def _stop_workers(self, worker_group: WorkerGroup) -> None:
@ -487,7 +487,7 @@ class SimpleElasticAgent(ElasticAgent):
``WorkerState``. That is, it must gracefully handle stopping
non-existent workers, unhealthy (stuck) workers, etc.
"""
raise NotImplementedError()
raise NotImplementedError
@abc.abstractmethod
def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
@ -495,7 +495,7 @@ class SimpleElasticAgent(ElasticAgent):
This function also returns the new state of the worker group.
"""
raise NotImplementedError()
raise NotImplementedError
@abc.abstractmethod
def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:
@ -504,7 +504,7 @@ class SimpleElasticAgent(ElasticAgent):
Args:
death_sig: Signal to send to the child process, SIGTERM is default
"""
raise NotImplementedError()
raise NotImplementedError
@staticmethod
def _set_master_addr_port(

View File

@ -458,7 +458,7 @@ class PContext(abc.ABC):
@abc.abstractmethod
def _start(self) -> None:
"""Start processes using strategy defined in a particular context."""
raise NotImplementedError()
raise NotImplementedError
@abc.abstractmethod
def _poll(self) -> Optional[RunProcsResult]:
@ -469,7 +469,7 @@ class PContext(abc.ABC):
successfully or any process fails. Returns ``None`` if
all processes are still running.
"""
raise NotImplementedError()
raise NotImplementedError
def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
"""
@ -514,7 +514,7 @@ class PContext(abc.ABC):
@abc.abstractmethod
def pids(self) -> Dict[int, int]:
"""Return pids of processes mapped by their respective local_ranks."""
raise NotImplementedError()
raise NotImplementedError
@abc.abstractmethod
def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
@ -522,7 +522,7 @@ class PContext(abc.ABC):
Terminates all processes managed by this context and cleans up any
meta resources (e.g. redirect, error_file files).
"""
raise NotImplementedError()
raise NotImplementedError
def close(
self, death_sig: Optional[signal.Signals] = None, timeout: int = 30

View File

@ -655,10 +655,10 @@ class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor):
continue
if action == _Action.ERROR_CLOSED:
raise RendezvousClosedError()
raise RendezvousClosedError
if action == _Action.ERROR_TIMEOUT:
raise RendezvousTimeoutError()
raise RendezvousTimeoutError
if action == _Action.SYNC:
# Delay the execution by one second to avoid overloading the

View File

@ -270,7 +270,7 @@ class EtcdRendezvous:
self._rendezvous_deadline = time.time() + self._timeout
while True:
if time.time() > self._rendezvous_deadline:
raise RendezvousTimeoutError()
raise RendezvousTimeoutError
logger.info("Attempting to join next rendezvous")
try:
@ -340,17 +340,17 @@ class EtcdRendezvous:
logger.info("Observed existing rendezvous state: %s", state)
if state["status"] == "closed":
raise RendezvousClosedError()
raise RendezvousClosedError
if state["status"] == "joinable":
return self.join_phase(state["version"])
if state["status"] == "final":
self.handle_existing_rendezvous(state["version"])
raise EtcdRendezvousRetryImmediately()
raise EtcdRendezvousRetryImmediately
self.try_wait_for_state_change(etcd_index=active_version.etcd_index + 1)
raise EtcdRendezvousRetryableFailure()
raise EtcdRendezvousRetryableFailure
def join_phase(self, expected_version):
"""
@ -632,7 +632,7 @@ class EtcdRendezvous:
active_version, state = self.get_rdzv_state()
if state["status"] != "final" or state["version"] != expected_version:
raise EtcdRendezvousRetryImmediately()
raise EtcdRendezvousRetryImmediately
# Increment counter to signal an additional waiting worker.
state["num_workers_waiting"] += 1
@ -714,7 +714,7 @@ class EtcdRendezvous:
pass
if time.time() > self._rendezvous_deadline:
raise RendezvousTimeoutError()
raise RendezvousTimeoutError
active_version, state = self.get_rdzv_state()
def handle_join_last_call(self, expected_version, deadline):
@ -832,7 +832,7 @@ class EtcdRendezvous:
pass
if time.time() > self._rendezvous_deadline:
raise RendezvousTimeoutError()
raise RendezvousTimeoutError
# Unfortunately, we have to do another fetch in order to get last etcd_index.
return self.get_rdzv_state()

View File

@ -963,7 +963,7 @@ class _PatchedFn(NamedTuple):
orig_fn: Any
def revert(self):
raise NotImplementedError()
raise NotImplementedError
class _PatchedFnSetItem(_PatchedFn):

View File

@ -22,7 +22,7 @@ SYM_FUNCTION_MODE: Optional["SymDispatchMode"] = None
#
class SymDispatchMode:
def __sym_dispatch__(self, func, types, args, kwargs):
raise NotImplementedError()
raise NotImplementedError
def __enter__(self):
global SYM_FUNCTION_MODE

View File

@ -31,7 +31,7 @@ class OperatorSupportBase(abc.ABC):
def is_node_supported(
self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
raise NotImplementedError()
raise NotImplementedError
@compatibility(is_backward_compatible=False)

View File

@ -476,7 +476,7 @@ if _enabled:
# RecursiveScriptClass.
def forward_magic_method(self, method_name, *args, **kwargs):
if not self._c._has_method(method_name):
raise TypeError()
raise TypeError
self_method = self.__getattr__(method_name)
return self_method(*args, **kwargs)
@ -865,7 +865,7 @@ if _enabled:
if getattr(self_method, "__func__", None) == getattr(
RecursiveScriptModule, method_name
):
raise NotImplementedError()
raise NotImplementedError
return self_method(*args, **kwargs)
def __iter__(self):

View File

@ -1184,7 +1184,7 @@ class _LazyConvXdMixin(LazyModuleMixin):
# Function to return the number of spatial dims expected for inputs to the module.
# This is expected to be implemented by subclasses.
def _get_num_spatial_dims(self) -> int:
raise NotImplementedError()
raise NotImplementedError
# LazyConv1d defines weight as a Tensor but derived class defines it as UnitializeParameter

View File

@ -264,7 +264,7 @@ class Invocation:
# TODO: Implement this.
# Tracks top level call arguments and diagnostic options.
def __init__(self) -> None:
raise NotImplementedError()
raise NotImplementedError
@dataclasses.dataclass

View File

@ -1910,7 +1910,7 @@ class TorchFunctionMode:
pass
def __torch_function__(self, func, types, args=(), kwargs=None):
raise NotImplementedError()
raise NotImplementedError
def __enter__(self):
_push_mode(self)

View File

@ -36,7 +36,7 @@ class ErrorMeta(Exception):
super().__init__(
"If you are a user and see this message during normal operation "
"please file an issue at https://github.com/pytorch/pytorch/issues. "
"If you are a developer and working on the comparison functions, please `raise ErrorMeta().to_error()` "
"If you are a developer and working on the comparison functions, please `raise ErrorMeta.to_error()` "
"for user facing errors."
)
self.type = type
@ -336,7 +336,7 @@ class Pair(abc.ABC):
@staticmethod
def _inputs_not_supported() -> NoReturn:
raise UnsupportedInputs()
raise UnsupportedInputs
@staticmethod
def _check_inputs_isinstance(*inputs: Any, cls: Union[Type, Tuple[Type, ...]]):
@ -1217,7 +1217,7 @@ def not_close_error_metas(
)
except ErrorMeta as error_meta:
# Explicitly raising from None to hide the internal traceback
raise error_meta.to_error() from None
raise error_meta.to_error() from None # noqa: RSE102
error_metas: List[ErrorMeta] = []
for pair in pairs:

View File

@ -9390,7 +9390,7 @@ class DistributedTest:
@staticmethod
def backward(ctx, grad_output):
raise RuntimeError()
raise RuntimeError
class MyModel(nn.Module):
def __init__(self, device):
@ -9534,7 +9534,7 @@ class DistributedTest:
@staticmethod
def backward(ctx, grad_output):
raise RuntimeError()
raise RuntimeError
class MyModel(torch.nn.Module):
def __init__(self, device):

View File

@ -67,7 +67,7 @@ class TorchDispatchMode:
self.old_dispatch_mode_flag = False
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise NotImplementedError()
raise NotImplementedError
def __enter__(self):
global _is_in_torch_dispatch_mode

View File

@ -1085,7 +1085,7 @@ class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
if target_frame.early_stop and target_frame.recomp_counter[gid] == len(
target_frame.weak_holders
):
raise _StopRecomputationError()
raise _StopRecomputationError
# See Rule 6: [ retain_graph is True ] above
return x.detach()

View File

@ -12,7 +12,7 @@ class _BaseDatasetFetcher:
self.drop_last = drop_last
def fetch(self, possibly_batched_index):
raise NotImplementedError()
raise NotImplementedError
class _IterableDatasetFetcher(_BaseDatasetFetcher):

View File

@ -92,7 +92,7 @@ class Capture:
if attrname == 'kwarg' or attrname == 'kwargs':
raise Exception('no kwargs!')
if attrname in ['__deepcopy__']:
raise AttributeError()
raise AttributeError
result = CaptureGetAttr(self, attrname, ctx=self.ctx)
return result

View File

@ -79,7 +79,7 @@ class Sampler(Generic[T_co]):
# Calling `len(subclass_instance)` raises:
# TypeError: 'NotImplementedType' object cannot be interpreted as an integer
#
# + `raise NotImplementedError()`:
# + `raise NotImplementedError`:
# This prevents triggering some fallback behavior. E.g., the built-in
# `list(X)` tries to call `len(X)` first, and executes a different code
# path if the method is not found or `NotImplemented` is returned, while

View File

@ -507,7 +507,7 @@ def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
elif k is UfuncKey.CPUVector:
compute_t = VectorizedCType(BaseCType(scalar_t))
else:
raise AssertionError()
raise AssertionError
inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
if k not in inner_ufunc_sigs:
inner_ufunc_sigs[k] = UfuncSignature(

View File

@ -1151,7 +1151,7 @@ def compute_cpp_argument_yaml(
arg["default"] = cpp_a.default
return arg
elif isinstance(cpp_a.argument, SelfArgument):
raise AssertionError()
raise AssertionError
elif isinstance(cpp_a.argument, Argument):
return compute_argument_yaml(
cpp_a.argument,