Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
parent b726a23d4e
commit 93e249969b
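As an illustration of what the RSE rule changes (not part of the commit itself; the function and exception names below are made up), the pattern matches the hunks that follow. Calling an exception class with empty parentheses and raising the bare class are equivalent at runtime, because Python instantiates the class automatically, which is why the rewrite is safe to apply mechanically:

```python
class WidgetError(Exception):
    """Hypothetical exception used only for this illustration."""


def before(flag):
    # Flagged by ruff's RSE check: unnecessary parentheses on a raised exception.
    if not flag:
        raise WidgetError()


def after(flag):
    # Equivalent behavior; the bare class is instantiated implicitly.
    if not flag:
        raise WidgetError


def chained_after(path):
    # The same rewrite applies when chaining with `from`, as in several hunks below.
    try:
        return open(path).read()
    except OSError as err:
        raise WidgetError from err


def unaffected(x):
    # Raises that pass arguments keep their parentheses; only empty calls are rewritten.
    if x < 0:
        raise ValueError("x must be non-negative")
```

With the rule selected, ruff can apply this rewrite automatically via its autofix, which is presumably how the bulk of this diff was produced.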
@@ -1858,7 +1858,7 @@ class TimeOutException(Exception):
 def alarm_handler(signum, frame):
-raise TimeOutException()
+raise TimeOutException
 def exit_after(s):

@@ -2136,7 +2136,7 @@ class BenchmarkRunner:
 return set()
 def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
-raise NotImplementedError()
+raise NotImplementedError
 @property
 def equal_nan(self):

@@ -80,7 +80,7 @@ def serialize_sparse_tensor(e):
 def deserialize_sparse_tensor(size, dtype, layout, is_coalesced, nnz=None):
-raise NotImplementedError()
+raise NotImplementedError
 def deserialize_tensor(size, dtype, stride=None):

@@ -126,6 +126,7 @@ select = [
 "PT025",
 "PT026",
 "PYI",
+"RSE",
 "RUF008", # mutable dataclass default
 "RUF015", # access first ele in constant time
 "RUF016", # type error non-integer index

@@ -578,7 +578,7 @@ def main():
 with open(filename, "w") as f:
 f.writelines(lines)
 return
-raise AssertionError()
+raise AssertionError
 if __name__ == "__main__":

@@ -6,10 +6,10 @@ import torch
 class Setup:
 def setup(self):
-raise NotImplementedError()
+raise NotImplementedError
 def shutdown(self):
-raise NotImplementedError()
+raise NotImplementedError
 class FileSetup:

@@ -21,7 +21,7 @@ class Dummymodel(nn.Module):
 super().__init__()
 def forward(self, x):
-raise NotImplementedError()
+raise NotImplementedError
 class EPModel(nn.Module):

@@ -31,7 +31,7 @@ class EPModel(nn.Module):
 self.net2 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
 def forward(self, x):
-raise NotImplementedError()
+raise NotImplementedError
 class SecondTier(nn.Module):

@@ -43,7 +43,7 @@ class SecondTier(nn.Module):
 self.net = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
 def forward(self, x):
-raise NotImplementedError()
+raise NotImplementedError
 class TopModel(nn.Module):

@@ -55,7 +55,7 @@ class TopModel(nn.Module):
 self.net = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
 def forward(self, x):
-raise NotImplementedError()
+raise NotImplementedError
 class TestFSDPWithEP(DTensorTestBase, VerifyStateDictMixin):

@@ -34,7 +34,7 @@ class TestMetricsHandler(MetricHandler):
 class Parent(abc.ABC):
 @abc.abstractmethod
 def func(self):
-raise NotImplementedError()
+raise NotImplementedError
 def base_func(self):
 self.func()

@@ -57,7 +57,7 @@ class MetricsApiTest(TestCase):
 @prof
 def throw(self):
-raise RuntimeError()
+raise RuntimeError
 @prof(group="torchelastic")
 def bar2(self):

@@ -197,7 +197,7 @@ class _DummyRendezvousHandler(RendezvousHandler):
 return "dummy_backend"
 def next_rendezvous(self) -> Tuple[Store, int, int]:
-raise NotImplementedError()
+raise NotImplementedError
 def is_closed(self) -> bool:
 return False

@@ -38,7 +38,7 @@ def _create_c10d_store_mp(is_server, server_addr, port, world_size, wait_for_wor
 timeout=2,
 )
 if store is None:
-raise AssertionError()
+raise AssertionError
 store.set(f"test_key/{os.getpid()}", b"test_value")
@@ -363,7 +363,7 @@ class TestFSDPOptimState(FSDPTest):
 # these settings are not implemented since the transformer is
 # wrapped with FSDP at the top-level, which means that there is
 # only a single flat parameter, making these booleans vacuous
-raise NotImplementedError()
+raise NotImplementedError
 if group is None:
 group = dist.distributed_c10d._get_default_group()
 model = TransformerWithSharedParams.init(

@@ -63,7 +63,7 @@ def test_exception_no_hang(setup_rpc):
 class Raise(nn.Module):
 def forward(self, x):
-raise ExpectedException()
+raise ExpectedException
 model = nn.Sequential(Pass(), Pass(), Raise())
 model = Pipe(model, chunks=3)

@@ -231,7 +231,7 @@ def test_exception(setup_rpc):
 class Raise(nn.Module):
 def forward(self, *_):
-raise ExpectedException()
+raise ExpectedException
 model = nn.Sequential(Raise())
 model = Pipe(model, chunks=1)

@@ -265,7 +265,7 @@ def test_exception_early_stop_asap(setup_rpc):
 class Raise(nn.Module):
 def forward(self, x):
-raise ExpectedException()
+raise ExpectedException
 model = nn.Sequential(Pass(), Pass(), Counter(), Raise())
 model = Pipe(model, chunks=3)

@@ -752,7 +752,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
 new_data = args[0]._data.view(*args[1:])
 return FooTensor(new_data, args[0]._config, args[0]._scale)
-raise NotImplementedError()
+raise NotImplementedError
 class foo_autograd_fn(torch.autograd.Function):
 @staticmethod

@@ -47,7 +47,7 @@ from user code:
 def test_internal_error_suppress_errors(self, records):
 def fn001(x):
 def f(ctx):
-raise AssertionError()
+raise AssertionError
 comptime(f)

@@ -62,7 +62,7 @@ WON'T CONVERT fn001 test_exc.py line N
 ========== TorchDynamo Stack Trace ==========
 Traceback (most recent call last):
 File "test_exc.py", line N, in f
-raise AssertionError()
+raise AssertionError
 AssertionError:
 from user code:

@@ -84,7 +84,7 @@ from user code:
 def test_not_implemented_error(self, records):
 def fn001(x):
 def f(ctx):
-raise NotImplementedError()
+raise NotImplementedError
 # Ensure graph break is not possible
 for i in range(3):

@@ -101,7 +101,7 @@ WON'T CONVERT fn001 test_exc.py line N
 due to:
 Traceback (most recent call last):
 File "test_exc.py", line N, in f
-raise NotImplementedError()
+raise NotImplementedError
 torch._dynamo.exc.InternalTorchDynamoError:
 from user code:

@@ -128,7 +128,7 @@ from user code:
 # NB: avoid decorator, as 3.11 changed the line number attributed
 # in this situation
 def f(ctx):
-raise AssertionError()
+raise AssertionError
 comptime(f)

@@ -164,7 +164,7 @@ from user code:
 import torch._inductor.lowering
 def throw(x):
-raise AssertionError()
+raise AssertionError
 # inject an error in the lowerings
 dict_entries = {}

@@ -189,7 +189,7 @@ WON'T CONVERT inductor_error_fn test_logging.py line N
 due to:
 Traceback (most recent call last):
 File "test_logging.py", line N, in throw
-raise AssertionError()
+raise AssertionError
 torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
 LoweringException: AssertionError:
 target: aten.round.default

@@ -396,12 +396,12 @@ class ListConfig:
 if self.resolve:
 x = x._dereference_node()
 if x._is_missing():
-raise AssertionError()
+raise AssertionError
 self.index = self.index + 1
 if isinstance(x, ListConfig.ValueNode):
 return x._value()
-raise AssertionError()
+raise AssertionError
 def __iter__(self):
 return self._iter_ex(True)

@@ -410,7 +410,7 @@ class ListConfig:
 try:
 return ListConfig.ListIterator(self, resolve)
 except Exception:
-raise AssertionError()
+raise AssertionError
 def __init__(self):
 self._content = [

@@ -545,7 +545,7 @@ def apply_chunking_to_forward(forward_fn, *input_tensors):
 assert all(input_tensor.shape[1] == tensor_shape for input_tensor in input_tensors)
 num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
 if num_args_in_forward_chunk_fn != len(input_tensors):
-raise ValueError()
+raise ValueError
 return forward_fn(*input_tensors)

@@ -848,7 +848,7 @@ def _merge_criteria_processor_list(default_list, custom_list):
 for default in default_list:
 for custom in custom_list:
 if type(custom) is type(default):
-raise ValueError()
+raise ValueError
 default_list.extend(custom_list)
 return default_list

@@ -2573,7 +2573,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
 if self.i < 3:
 self.i += 1
 return self.i
-raise StopIteration()
+raise StopIteration
 @torch.compile(backend="eager", fullgraph=True)
 def fn(x):
@@ -147,10 +147,10 @@ class SkipNonTensorTests(torch._dynamo.test_case.TestCase):
 class Foo(list):
 def __iter__(self):
-raise Exception()
+raise Exception
 def __len__(self):
-raise Exception()
+raise Exception
 x = Foo()
 x.append(torch.randn(4))

@@ -233,7 +233,7 @@ class StructuredTraceTest(TestCase):
 import torch._inductor.lowering
 def throw(x):
-raise AssertionError()
+raise AssertionError
 # inject an error in the lowerings
 dict_entries = {}

@@ -732,12 +732,12 @@ class Operator:
 def any_opinfo_attr(self, attr):
 if not self.has_opinfo():
-raise RuntimeError()
+raise RuntimeError
 return any(getattr(opinfo, attr) for opinfo in self.opinfos)
 def all_opinfo_attr(self, attr):
 if not self.has_opinfo():
-raise RuntimeError()
+raise RuntimeError
 return all(getattr(opinfo, attr) for opinfo in self.opinfos)
 def supports_vjp(self):

@@ -870,7 +870,7 @@ class OperatorSet:
 elif n.startswith(torch_dot):
 names_sanitized.append(n[len(torch_dot) :])
 else:
-raise AssertionError()
+raise AssertionError
 return cls.from_names(names_sanitized)
 def query(self, operator_method, filter=(Support.NO, Support.YES, Support.UNKNOWN)):

@@ -403,7 +403,7 @@ class TestMin(TestCase):
 # test with too many elements
 try:
 A[1, ..., 1, 1]
-raise NotImplementedError()
+raise NotImplementedError
 except IndexError:
 pass
 c, d = dims()

@@ -415,7 +415,7 @@ class TestMin(TestCase):
 )
 try:
 A[..., 3, ...]
-raise NotImplementedError()
+raise NotImplementedError
 except DimensionBindError:
 pass

@@ -285,7 +285,7 @@ class TestInductorDynamic(TestCase):
 @custom_ops.custom_op("test::foo")
 def foo(x: torch.Tensor, y: int) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl("test::foo")
 def foo_impl(x: torch.Tensor, y: int) -> torch.Tensor:

@@ -401,7 +401,7 @@ class TestInductorDynamic(TestCase):
 @custom_ops.custom_op("test::foo")
 def foo(x: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl("test::foo")
 def foo_impl(x: torch.Tensor) -> torch.Tensor:

@@ -2632,7 +2632,7 @@ class TestFrozenOptimizations(JitTestCase):
 res3 = torch._C._nn.linear(in_tensor2, self.w2, self.b2)
 res4 = torch._C._nn.linear(in_tensor1, self.w2, self.b1)
 else:
-raise AssertionError()
+raise AssertionError
 res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b1)
 return res1, res2, res3, res4

@@ -2871,7 +2871,7 @@ class TestScriptList(JitTestCase):
 def __next__(self):
 if self.value == limit: # noqa: F821
-raise StopIteration()
+raise StopIteration
 ret = self.value
 self.value += 1

@@ -54,7 +54,7 @@ class ClosuresTest(TestCase):
 torch._lazy.add_step_closure(closure)
 torch._lazy.mark_step()
-raise AssertionError() # Should not reach here
+raise AssertionError # Should not reach here
 except RuntimeError as e:
 assert flag.is_set(), "Should have caught exception from closure"

@@ -79,7 +79,7 @@ class ClosuresTest(TestCase):
 torch._lazy.add_step_closure(closure2, run_async=True)
 torch._lazy.mark_step()
-raise AssertionError() # Should not reach here
+raise AssertionError # Should not reach here
 except RuntimeError as e:
 # Should have caught exception from closure1
 pass

@@ -283,7 +283,7 @@ class TestOperators(common_utils.TestCase):
 def symbolic(g, x):
 # The inside of this function should never be invoked, because
 # we will fail due to an argument mismatch first.
-raise AssertionError()
+raise AssertionError
 @staticmethod
 def forward(ctx, x, y):

@@ -154,7 +154,7 @@ class Errors:
 NB: It is an error to "fail" without having added any errors to
 the error context.
 """
-raise self.exc_class()
+raise self.exc_class
 def failWith(self, msg):
 """

@@ -489,7 +489,7 @@ def verify(
 errs.requireEqual(
 proto_bytes.getvalue(), alt_proto_bytes.getvalue()
 )
-raise AssertionError()
+raise AssertionError
 # TODO: test that the traced model also returns the same thing...
 run_helper(torch_out, args, remained_onnx_input_idx)

@@ -98,7 +98,7 @@ class TestImporter(PackageTestCase):
 self._whichmodule_return = whichmodule_return
 def import_module(self, module_name):
-raise NotImplementedError()
+raise NotImplementedError
 def whichmodule(self, obj, name):
 return self._whichmodule_return

@@ -1970,7 +1970,7 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
 try:
 with cm:
 x.add(y)
-raise ValueError()
+raise ValueError
 x.relu()
 except ValueError:
 pass

@@ -509,7 +509,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def baz(x: Tensor) -> Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 def test_unsupported_schemas(self):
 with self.assertRaisesRegex(ValueError, "only supports functional"):
@@ -670,35 +670,35 @@ class TestCustomOp(CustomOpTestCaseBase):
 with self.assertRaisesRegex(ValueError, "varargs"):
 def foo(*args):
-raise NotImplementedError()
+raise NotImplementedError
 infer_schema(foo)
 with self.assertRaisesRegex(ValueError, "varkwargs"):
 def foo(**kwargs):
-raise NotImplementedError()
+raise NotImplementedError
 infer_schema(foo)
 with self.assertRaisesRegex(ValueError, "must have a type annotation"):
 def foo(x):
-raise NotImplementedError()
+raise NotImplementedError
 infer_schema(foo)
 with self.assertRaisesRegex(ValueError, "unsupported"):
 def foo(x: Tensor) -> Tuple[Tensor, ...]:
-raise NotImplementedError()
+raise NotImplementedError
 infer_schema(foo)
 with self.assertRaisesRegex(ValueError, "can be mutated"):
 def foo(x: Tensor, y: int) -> Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 infer_schema(foo, mutates_args={"y"})

@@ -752,7 +752,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 @custom_ops.custom_op(f"{self.test_ns}::foo")
 def foo(x: Tensor) -> typ:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{self.test_ns}::foo")
 def foo_impl(x: Tensor) -> typ:

@@ -771,7 +771,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 @custom_ops.custom_op(f"{self.test_ns}::foo")
 def foo(x: Tensor) -> Tuple[typ, typ]:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{self.test_ns}::foo")
 def foo_impl(x: Tensor) -> Tuple[typ, typ]:

@@ -789,7 +789,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: Tensor, y: typ) -> Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 yeet = None

@@ -823,7 +823,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 @custom_ops.custom_op(f"{self.test_ns}::foo")
 def foo(x: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 called = 0

@@ -847,7 +847,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: Tensor, y: List[Optional[int]]) -> Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 del foo

@@ -855,7 +855,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 # int[N] in Dispatcher is a bit wild, so we don't try to support it.
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: Tensor, y: Tuple[int, int]) -> Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 del foo

@@ -863,7 +863,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: Tensor, y: Callable) -> Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 del foo

@@ -910,7 +910,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 @custom_ops.custom_op(f"{ns}::foo2")
 def foo2(x: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 def test_private_ctor(self):
 with self.assertRaisesRegex(RuntimeError, "CustomOp constructor is private"):

@@ -919,7 +919,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_lifetime(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 custom_op = torch._custom_op.impl.get_op(f"{TestCustomOp.test_ns}::foo")

@@ -928,7 +928,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811
-raise NotImplementedError()
+raise NotImplementedError
 # Unless we delete the original op.
 custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

@@ -936,14 +936,14 @@ class TestCustomOp(CustomOpTestCaseBase):
 # Smoke test
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811
-raise NotImplementedError()
+raise NotImplementedError
 custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
 def test_autograd_notimplemented(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811
-raise NotImplementedError()
+raise NotImplementedError
 x = torch.randn(3, requires_grad=True)
 op = self.get_op(f"{self.test_ns}::foo")
@@ -954,7 +954,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: Sequence[torch.Tensor]) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 x = torch.randn(3, requires_grad=True)
 y = torch.randn(3)

@@ -966,7 +966,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 x = torch.randn(3, requires_grad=True)
 y = torch.randn(3)

@@ -978,7 +978,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_autograd_notimplemented_gradmode(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
 def foo_impl(x, y):

@@ -994,7 +994,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_impl_cpu(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu")
 def foo_cpu(x):

@@ -1008,7 +1008,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_impl_invalid_devices(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 def foo_impl(x):
 return x.sin()

@@ -1033,7 +1033,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_backward_partially_registered(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
 def foo_impl(x):

@@ -1054,7 +1054,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_save_for_backward_inputs_are_namedtuple(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
 def foo_impl(x):

@@ -1084,7 +1084,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_backward_returns_dict(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
 def foo_impl(x):

@@ -1107,7 +1107,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_backward_dict_invalid_keys(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
 def foo_impl(x):

@@ -1130,7 +1130,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_backward_dict_grad_for_nontensor(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
 def foo_impl(x, dim):

@@ -1153,7 +1153,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_backward_dict_requires_keys_for_input_tensors(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
 def foo_impl(x, y):

@@ -1176,7 +1176,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_backward_dict_requires_keys_for_input_optional_tensors(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
 def foo_impl(x, y):

@@ -1199,7 +1199,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_backward_grads_are_tensor_or_none(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
 def foo_impl(x):

@@ -1222,7 +1222,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_backward_tensorlist_input_requires_list_grads_with_same_numel(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
 def foo_impl(xs):

@@ -1245,7 +1245,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_backward_tensorlist_input_requires_list_grads_none_or_Tensor(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
 def foo_impl(xs):

@@ -1268,7 +1268,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_backward_tensorlist_input_requires_list_grads(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
 def foo_impl(xs):

@@ -1291,7 +1291,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_backward_output_differentiability_type(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 with self.assertRaisesRegex(RuntimeError, "output_differentiability"):
@@ -1304,7 +1304,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_backward_output_differentiability_numel(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(xs: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
-raise NotImplementedError()
+raise NotImplementedError
 with self.assertRaisesRegex(RuntimeError, "output_differentiability"):

@@ -1317,7 +1317,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_backward_output_differentiability_tensorlist(self):
 @custom_ops.custom_op(f"{self.test_ns}::foo")
 def foo(x: Tensor) -> Tuple[List[Tensor], Tensor]:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{self.test_ns}::foo")
 def foo_impl(x):

@@ -1343,7 +1343,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_backward_output_differentiability_non_tensor(self):
 @custom_ops.custom_op(f"{self.test_ns}::foo")
 def foo(x: Tensor) -> Tuple[Tensor, int]:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{self.test_ns}::foo")
 def foo_impl(x):

@@ -1368,7 +1368,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_impl_separate(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu")
 def foo_cpu(x):

@@ -1392,7 +1392,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_impl_multiple(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
 def foo_impl(x):

@@ -1422,7 +1422,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_impl_meta(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
 def foo_meta(x, dim):

@@ -1438,7 +1438,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_duplicate_impl(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
 def foo_meta(x, dim):

@@ -1457,7 +1457,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_new_data_dependent_symint(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
 def foo_meta(x):

@@ -1483,7 +1483,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 # this one is just a sanity check.
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
 def foo_meta(x):

@@ -1497,7 +1497,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 def test_not_implemented_error(self):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
 def foo(x: torch.Tensor) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 x = torch.randn(3)
 op = self.get_op(f"{self.test_ns}::foo")

@@ -1510,7 +1510,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar")
 def bar(sizes: Sequence[int]) -> torch.Tensor:
-raise NotImplementedError()
+raise NotImplementedError
 op = self.get_op(f"{self.test_ns}::bar")
 with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"):

@@ -2021,7 +2021,7 @@ class MiniOpTest(CustomOpTestCaseBase):
 @staticmethod
 def backward(ctx, grad):
-raise NotImplementedError()
+raise NotImplementedError
 def autograd_impl(x):
 return Op.apply(x)

@@ -492,7 +492,7 @@ class TestBinary(TestCase):
 mt1 = masked_tensor(data1, mask1)
 try:
 fn(mt0, mt1)
-raise AssertionError()
+raise AssertionError
 except ValueError as e:
 assert (
 "Input masks must match. If you need support for this, please open an issue on Github."
@@ -6603,7 +6603,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
 elif weight_layout == torch.sparse_coo:
 module.weight = nn.Parameter(module.weight.to_sparse_coo())
 else:
-raise AssertionError()
+raise AssertionError
 inp = torch.randn(4, requires_grad=True, device=device)
 res = module(inp)

@@ -1206,7 +1206,7 @@ class TestTorchFunctionMode(TestCase):
 class A(TorchFunctionMode):
 def __torch_function__(self, *args, **kwargs):
-raise ErrorA()
+raise ErrorA
 with self.assertRaises(ErrorA):
 with A():

@@ -1218,7 +1218,7 @@ class TestTorchFunctionMode(TestCase):
 class A(TorchFunctionMode):
 def __torch_function__(self, *args, **kwargs):
-raise ErrorA()
+raise ErrorA
 x = A()
 with self.assertRaises(ErrorA):

@@ -1238,7 +1238,7 @@ $3: f32[] = torch._ops.aten.add.Tensor($1, $2)""")
 class AMode(TorchDispatchMode):
 def __torch_dispatch__(self, func, types, args=(), kwargs=None):
 if func.__name__ == 'randn.default':
-raise RuntimeError()
+raise RuntimeError
 return A(torch.zeros(()))
 with AMode():

@@ -1254,7 +1254,7 @@ $3: f32[] = torch._ops.aten.add.Tensor($1, $2)""")
 class A(TorchDispatchMode):
 def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-raise ErrorA()
+raise ErrorA
 x = A()
 with self.assertRaises(ErrorA):

@@ -1081,7 +1081,7 @@ class TestTraceback(TestCase):
 source = '''\
 def f(x):
 def g(x):
-raise RuntimeError() # HEYA
+raise RuntimeError # HEYA
 x = x * 3
 return g(x) + 1

@@ -1099,7 +1099,7 @@ def f(x):
 def test_format_traceback_short(self):
 try:
-raise RuntimeError()
+raise RuntimeError
 except RuntimeError as e:
 self.assertRegex(format_traceback_short(e.__traceback__), r'.*test_utils.py:\d+ in test_format_traceback_short')

@@ -525,7 +525,7 @@ class WeakKeyDictionaryTestCase(TestCase):
 return self
 def __next__(self):
-raise Exc()
+raise Exc
 self.assertRaises(Exc, d.update, badseq())

@@ -866,7 +866,7 @@ class WeakKeyDictionaryScriptObjectTestCase(TestCase):
 return self
 def __next__(self):
-raise Exc()
+raise Exc
 self.assertRaises(Exc, d.update, badseq())

@@ -1136,14 +1136,14 @@ class TestCreation(TestCase):
 return 1
 def __getitem__(self, index):
-raise ValueError()
+raise ValueError
 class Map:
 def __len__(self):
 return 1
 def __getitem__(self, index):
-raise KeyError()
+raise KeyError
 a = np.array([Map()])
 assert_(a.shape == (1,))

@@ -1160,7 +1160,7 @@ class TestCreation(TestCase):
 if ind in [0, 1]:
 return ind
 else:
-raise IndexError()
+raise IndexError
 d = np.array([Point2(), Point2(), Point2()])
 assert_equal(d.dtype, np.dtype(object))

@@ -342,7 +342,7 @@ class TestFFT1D(TestCase):
 Y_res = fft(Y, axes=ax)
 assert_allclose(X_res, Y_res, atol=_tol, rtol=_tol)
 else:
-raise ValueError()
+raise ValueError
 @skipif(IS_WASM, reason="Cannot start thread")

@@ -1435,7 +1435,7 @@ class TestVectorize(TestCase):
 try:
 vectorize(random.randrange) # Should succeed
 except Exception:
-raise AssertionError() # noqa: TRY200
+raise AssertionError # noqa: TRY200
 def test_keywords2_ticket_2100(self):
 # Test kwarg support: enhancement ticket 2100

@@ -1252,7 +1252,7 @@ def emit_body(
 if a.name == derivative_var_name:
 break
 else:
-raise AssertionError()
+raise AssertionError
 return f"grad_fn->should_compute_output({edge_off})"
 if is_inplace_foreach:
@@ -61,7 +61,7 @@ def custom_op(qualname, func_or_schema=None):
 >>> # we will infer the types of the inputs and outputs.
 >>> @torch._custom_ops.custom_op("mylibrary::numpy_sin")
 >>> def numpy_sin(x: Tensor) -> Tensor:
->>> raise NotImplementedError()
+>>> raise NotImplementedError
 >>>
 >>> # The custom op is now accessible via the torch.ops module:
 >>> torch.ops.mylibrary.numpy_sin

@@ -143,7 +143,7 @@ def impl(qualname, *, device_types=("cpu", "cuda"), func=None):
 >>> # we will infer the types of the inputs and outputs.
 >>> @torch._custom_ops.custom_op("mylibrary::numpy_cos")
 >>> def numpy_cos(x: Tensor) -> Tensor:
->>> raise NotImplementedError()
+>>> raise NotImplementedError
 >>>
 >>> # The custom op is now accessible via the torch.ops module:
 >>> torch.ops.mylibrary.numpy_cos

@@ -207,7 +207,7 @@ def impl_abstract(qualname, *, func=None):
 >>> # Example 1: an operator without data-dependent output shape
 >>> @torch._custom_ops.custom_op("mylibrary::custom_linear")
 >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
->>> raise NotImplementedError()
+>>> raise NotImplementedError
 >>>
 >>> @torch._custom_ops.impl_abstract("mylibrary::custom_linear")
 >>> def custom_linear_abstract(x, weight):

@@ -129,7 +129,7 @@ class TestingOnlyCompileError(Exception):
 def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
 for node in gm.graph.nodes:
 if node.target == torch.relu:
-raise ReluCompileError()
+raise ReluCompileError
 return gm

@@ -165,7 +165,7 @@ def non_leaf_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs
 return gm
 for t in example_inputs:
 if not t.is_leaf:
-raise TestingOnlyCompileError()
+raise TestingOnlyCompileError
 return gm

@@ -39,7 +39,7 @@ class DeviceInterface(metaclass=DeviceInterfaceMeta):
 class device:
 def __new__(cls, device: _device_t):
-raise NotImplementedError()
+raise NotImplementedError
 class Worker:
 """

@@ -51,71 +51,71 @@ class DeviceInterface(metaclass=DeviceInterfaceMeta):
 @staticmethod
 def set_device(device: int):
-raise NotImplementedError()
+raise NotImplementedError
 @staticmethod
 def current_device() -> int:
-raise NotImplementedError()
+raise NotImplementedError
 @staticmethod
 def get_device_properties(device: _device_t = None):
-raise NotImplementedError()
+raise NotImplementedError
 @staticmethod
 def current_device():
-raise NotImplementedError()
+raise NotImplementedError
 @staticmethod
 def set_device(device: _device_t):
-raise NotImplementedError()
+raise NotImplementedError
 @staticmethod
 def maybe_exchange_device(device: int) -> int:
-raise NotImplementedError()
+raise NotImplementedError
 @staticmethod
 def exchange_device(device: int) -> int:
-raise NotImplementedError()
+raise NotImplementedError
 @staticmethod
 def device_count():
-raise NotImplementedError()
+raise NotImplementedError
 @staticmethod
 def is_available() -> bool:
-raise NotImplementedError()
+raise NotImplementedError
 @staticmethod
 def stream(stream: torch.Stream):
-raise NotImplementedError()
+raise NotImplementedError
 @staticmethod
 def current_stream():
-raise NotImplementedError()
+raise NotImplementedError
 @staticmethod
 def set_stream(stream: torch.Stream):
-raise NotImplementedError()
+raise NotImplementedError
 @staticmethod
 def _set_stream_by_id(stream_id: int, device_index: int, device_type: int):
-raise NotImplementedError()
+raise NotImplementedError
 @staticmethod
 def get_raw_stream():
-raise NotImplementedError()
+raise NotImplementedError
 @staticmethod
 def synchronize(device: _device_t = None):
-raise NotImplementedError()
+raise NotImplementedError
 @staticmethod
 def get_device_properties(device: _device_t = None):
-raise NotImplementedError()
+raise NotImplementedError
 @staticmethod
 def get_compute_capability(device: _device_t = None):
-raise NotImplementedError()
+raise NotImplementedError
 class DeviceGuard:
@@ -215,7 +215,7 @@ class EphemeralSource(Source):
 return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>"
 def make_guard(self):
-raise NotImplementedError()
+raise NotImplementedError
 def is_ephemeral(self):
 return True

@@ -277,7 +277,7 @@ class NegateSource(ChainedSource):
 assert self.base is not None
 def reconstruct(self, codegen):
-raise NotImplementedError()
+raise NotImplementedError
 def guard_source(self):
 return self.base.guard_source()

@@ -516,7 +516,7 @@ class ConstantSource(Source):
 return self.source_name
 def make_guard(self, fn):
-raise NotImplementedError()
+raise NotImplementedError
 @dataclasses.dataclass(frozen=True)

@@ -1242,7 +1242,7 @@ class InstructionTranslatorBase(
 if (
 isinstance(val, BuiltinVariable) and val.fn is StopIteration
 ) or isinstance(val, variables.StopIterationVariable):
-raise exc.UserStopIteration()
+raise exc.UserStopIteration
 unimplemented(f"raise {exc}")
 else:
 unimplemented("raise ... from ...")

@@ -2231,7 +2231,7 @@ class InstructionTranslator(InstructionTranslatorBase):
 return self.f_locals[source.local_name]
 if isinstance(source, GlobalSource):
 return self.f_globals[source.global_name]
-raise KeyError()
+raise KeyError
 def run(self):
 super().run()

@@ -2388,7 +2388,7 @@ class InstructionTranslator(InstructionTranslatorBase):
 else create_instruction("RETURN_CONST", argval=inst.argval)
 )
 self.output.add_output_instructions([return_inst])
-raise ReturnValueOp()
+raise ReturnValueOp
 def RETURN_VALUE(self, inst):
 self._return(inst)

@@ -2637,7 +2637,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
 self.output.root_tx.mutated_closure_cell_contents.add(
 maybe_cell.source.name()
 )
-raise exc.UnspecializeRestartAnalysis()
+raise exc.UnspecializeRestartAnalysis
 unimplemented("write to __closure__ while inlining")
 def LOAD_DEREF(self, inst):

@@ -2676,12 +2676,12 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
 def RETURN_VALUE(self, inst):
 self.symbolic_result = self.pop() # type: ignore[assignment]
 self.instruction_pointer = None
-raise ReturnValueOp()
+raise ReturnValueOp
 def RETURN_CONST(self, inst):
 self.symbolic_result = self._load_const(inst)
 self.instruction_pointer = None
-raise ReturnValueOp()
+raise ReturnValueOp
 class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):

@@ -219,17 +219,17 @@ class VariableTracker(metaclass=VariableTrackerMeta):
 def make_guard(self, fn):
 if self.source:
 return self.source.make_guard(fn)
-raise NotImplementedError()
+raise NotImplementedError
 def const_getattr(self, tx, name: str) -> Any:
 """getattr(self, name) returning a python constant"""
-raise NotImplementedError()
+raise NotImplementedError
 def var_getattr(self, tx, name: str) -> "VariableTracker":
 """getattr(self, name) returning a new variable"""
 value = self.const_getattr(tx, name)
 if not variables.ConstantVariable.is_literal(value):
-raise NotImplementedError()
+raise NotImplementedError
 source = None
 if self.source:
 source = AttrSource(self.source, name)

@@ -257,7 +257,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
 return None
 def reconstruct(self, codegen):
-raise NotImplementedError()
+raise NotImplementedError
 def can_reconstruct(self, tx):
 """If it is possible to reconstruct the Python object this

@@ -273,7 +273,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
 return False
 def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
-raise NotImplementedError()
+raise NotImplementedError
 def has_unpack_var_sequence(self, tx) -> bool:
 try:

@@ -122,7 +122,7 @@ class ConstantVariable(VariableTracker):
 )
 member = getattr(self.value, name)
 if callable(member):
-raise NotImplementedError()
+raise NotImplementedError
 return member
 def call_method(

@@ -212,5 +212,5 @@ class EnumVariable(VariableTracker):
 def const_getattr(self, tx, name):
 member = getattr(self.value, name)
 if callable(member):
-raise NotImplementedError()
+raise NotImplementedError
 return member

@@ -420,7 +420,7 @@ class DictView(VariableTracker):
 def view_items_vt(self):
 # Returns an iterable of the unpacked items
 # Implement in the subclasses
-raise NotImplementedError()
+raise NotImplementedError
 def unpack_var_sequence(self, tx):
 def unwrap(x):

@@ -615,7 +615,7 @@ class DataClassVariable(ConstDictVariable):
 assert self.is_matching_cls(user_cls)
 def as_proxy(self):
-raise NotImplementedError()
+raise NotImplementedError
 def reconstruct(self, codegen):
 codegen.extend_output([codegen._create_load_const(self.user_cls)])

@@ -737,14 +737,14 @@ class CustomizedDictVariable(ConstDictVariable):
 # called from builder.py
 @classmethod
 def wrap(cls, builder, obj):
-raise NotImplementedError()
+raise NotImplementedError
 def __init__(self, items, user_cls, **options):
 super().__init__(items, user_cls, **options)
 assert self.is_matching_cls(user_cls)
 def as_proxy(self):
-raise NotImplementedError()
+raise NotImplementedError
 # 'RETURN_VALUE triggered compile'
 # called from torch/_dynamo/codegen.py
@@ -439,7 +439,7 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
 def get_function(self):
 if self.closure:
-raise NotImplementedError()
+raise NotImplementedError
 func = types.FunctionType(
 self.code.as_python_constant(),
 self.f_globals,

@@ -160,7 +160,7 @@ class RangeVariable(BaseListVariable):
 elif len(items_to_map) == 3:
 start, stop, step = items_to_map
 else:
-raise AssertionError()
+raise AssertionError
 assert stop is not None
 super().__init__([start, stop, step], **kwargs)

@@ -592,7 +592,7 @@ class SliceVariable(BaseListVariable):
 elif len(items_to_map) == 3:
 start, stop, step = items_to_map
 else:
-raise AssertionError()
+raise AssertionError
 if isinstance(start, variables.TensorVariable) or isinstance(
 stop, variables.TensorVariable

@@ -644,7 +644,7 @@ class ListIteratorVariable(VariableTracker):
 assert self.mutable_local
 old_index = self.index
 if old_index >= len(self.items):
-raise StopIteration()
+raise StopIteration
 tx.output.side_effects.mutation(self)
 self.index += 1
 return self.items[old_index]

@@ -665,7 +665,7 @@ class ListIteratorVariable(VariableTracker):
 def as_python_constant(self):
 if self.index > 0:
-raise NotImplementedError()
+raise NotImplementedError
 return iter([x.as_python_constant() for x in self.items])
 def unpack_var_sequence(self, tx):

@@ -748,7 +748,7 @@ class RestrictedListSubclassVariable(ListVariable):
 return [x.as_proxy() for x in self.items]
 def as_python_constant(self):
-raise NotImplementedError()
+raise NotImplementedError
 def is_python_constant(self):
 return False

@@ -592,13 +592,13 @@ class GetAttrVariable(VariableTracker):
 def const_getattr(self, tx, name):
 if not isinstance(self.obj, variables.NNModuleVariable):
-raise NotImplementedError()
+raise NotImplementedError
 step1 = tx.output.get_submodule(self.obj.module_key)
 if self.name not in step1.__dict__:
-raise NotImplementedError()
+raise NotImplementedError
 step2 = inspect.getattr_static(step1, self.name)
 if name not in step2.__dict__:
-raise NotImplementedError()
+raise NotImplementedError
 return inspect.getattr_static(step2, name)
 def reconstruct(self, codegen):

@@ -153,7 +153,7 @@ class NNModuleVariable(VariableTracker):
 # Mark the class dynamic unless its module initialization
 if tx.f_code.co_name != "__init__":
 GenerationTracker.mark_class_dynamic(type(mod))
-raise UnspecializeRestartAnalysis()
+raise UnspecializeRestartAnalysis
 def _custom_getattr_fallback(self, base, tx, name, options):
 """Check for a __getattr__ and handle it specially if it is implemented"""

@@ -157,7 +157,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
 ):
 return self.value.param_groups[arg.source.index]
-raise ArgMappingException()
+raise ArgMappingException
 new_args = [map_arg(arg) for arg in args]
 new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}

@@ -222,7 +222,7 @@ class TensorVariable(VariableTracker):
 return SourcelessBuilder.create(tx, example_value)
 if not (self.source and self.source.subguards_allowed()):
-raise NotImplementedError()
+raise NotImplementedError
 # For local source, we associate the real value. We use this real value
 # for implementing getattr fallthrough on the variable tracker base class.

@@ -238,23 +238,23 @@ class TensorVariable(VariableTracker):
 # Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope.
 _input_associated_real_value = eval(self.source.name(), scope)
 except Exception as exc:
-raise NotImplementedError() from exc
+raise NotImplementedError from exc
 if _input_associated_real_value is None:
-raise NotImplementedError()
+raise NotImplementedError
 if object_has_getattribute(_input_associated_real_value):
-raise NotImplementedError()
+raise NotImplementedError
 if get_custom_getattr(_input_associated_real_value):
-raise NotImplementedError()
+raise NotImplementedError
 real_value = getattr(_input_associated_real_value, name)
 if callable(real_value):
 # Callables have more nuanced handling, and we should let the existing system delegate here.
 # Raising was past behavior and so should always be sound to fall back.
 # Note - at a certain point we may want to handle
-raise NotImplementedError()
+raise NotImplementedError
 from ..guards import GuardBuilder
 from .builder import VariableBuilder

@@ -391,7 +391,7 @@ class TensorVariable(VariableTracker):
 result = self.dynamic_getattr(tx, name)
 if result is None:
-raise NotImplementedError()
+raise NotImplementedError
 return result
 def has_unpack_var_sequence(self, tx):

@@ -1090,7 +1090,7 @@ class NumpyNdarrayVariable(TensorVariable):
 elif name in ["__version__"]:
 unimplemented("delegate np.__version__ to NumPy")
 if result is None:
-raise NotImplementedError()
+raise NotImplementedError
 return result
 @staticmethod

@@ -72,7 +72,7 @@ class FuncTorchInterpreter(ABC):
 return self._cptr.key()
 def get_state(self):
-raise NotImplementedError()
+raise NotImplementedError
 def check_state(self, state):
 return state == self.get_state()
@ -797,17 +797,17 @@ class Source:
|
|||
return False
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError
|
||||
|
||||
def guard_source(self) -> GuardSource:
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError
|
||||
|
||||
def name(self) -> str:
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError
|
||||
|
||||
def make_guard(self, fn) -> Guard:
|
||||
if self.guard_source() is GuardSource.CONSTANT:
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError
|
||||
return Guard(self, fn)
|
||||
|
||||
def is_nn_module(self) -> bool:
|
||||
|
|
|
|||
|
|
@ -422,7 +422,7 @@ class BenchmarkRequest:
|
|||
def make_run_fn(
|
||||
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
|
||||
) -> Callable[[], None]:
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError
|
||||
|
||||
def cleanup_run_fn(self) -> None:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -446,7 +446,7 @@ def _reduce_tensor(t):
|
|||
# TODO: These tensors don't currently pickle, so we can't cache a
|
||||
# compiled graph containing them. Just fail now. If mkldnn tensors
|
||||
# get pickling support, we can remove this.
|
||||
raise BypassFxGraphCache()
|
||||
raise BypassFxGraphCache
|
||||
|
||||
# Very large tensors could be expensive to copy to cpu and hash. Let's
|
||||
# at least report if we find slowness.
|
||||
|
|
@ -598,7 +598,7 @@ class FxGraphHashDetails:
|
|||
# Some configs options are callables, e.g., post_grad_custom_pre_pass,
|
||||
# and may not pickle.
|
||||
log.debug("Can't pickle inductor config: %s", e)
|
||||
raise BypassFxGraphCache() from e
|
||||
raise BypassFxGraphCache from e
|
||||
|
||||
def debug_str(self) -> str:
|
||||
"""
|
||||
|
|

@ -843,19 +843,19 @@ class FxGraphCache:
"""
# Freezing can embed constants that wouldn't be static across runs.
if config.freezing or config.aot_inductor.use_runtime_constant_folding:
raise BypassFxGraphCache()
raise BypassFxGraphCache

# The treatment of guards in the caching implementation requires that
# we have a shape env.
if FxGraphCache._get_shape_env() is None:
log.debug("fx graph cache no shape env")
raise BypassFxGraphCache()
raise BypassFxGraphCache

# HigherOrderOperators should be handled on a case-by-case basis.
# Currently, we just skip caching if we have any.
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.HigherOrderOperator):
raise BypassFxGraphCache()
raise BypassFxGraphCache

@staticmethod
def load(

@ -990,7 +990,7 @@ def cpp_compiler_search(search: str) -> str:
return cxx
except (subprocess.SubprocessError, FileNotFoundError, ImportError):
continue
raise exc.InvalidCxxCompiler()
raise exc.InvalidCxxCompiler

def install_gcc_via_conda() -> str:

@ -2745,7 +2745,7 @@ def _worker_compile_triton(

class CodeCacheFuture:
def result(self):
raise NotImplementedError()
raise NotImplementedError

class TritonFuture(CodeCacheFuture):

@ -88,16 +88,16 @@ device_codegens: Dict[str, DeviceCodegen] = {}

class DeviceOpOverrides:
def import_get_raw_stream_as(self, name):
raise NotImplementedError()
raise NotImplementedError

def set_device(self, device_idx):
raise NotImplementedError()
raise NotImplementedError

def synchronize(self):
raise NotImplementedError()
raise NotImplementedError

def device_guard(self, device_idx):
raise NotImplementedError()
raise NotImplementedError

device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}

@ -1368,7 +1368,7 @@ class Kernel(CodeGen):
self.cse = cse

def load(self, name: str, index: sympy.Expr) -> CSEVariable:
raise NotImplementedError()
raise NotImplementedError

def indirect_load(self, name: str, index: sympy.Expr):
"""A load the depends on an index we have read"""

@ -1381,12 +1381,12 @@ class Kernel(CodeGen):
self.loads = prior

def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
raise NotImplementedError()
raise NotImplementedError

def store(
self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
) -> None:
raise NotImplementedError()
raise NotImplementedError

def reduction(
self,

@ -1395,7 +1395,7 @@ class Kernel(CodeGen):
reduction_type: ReductionType,
value: Union[CSEVariable, Tuple[CSEVariable, ...]],
) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
raise NotImplementedError()
raise NotImplementedError

def scan(
self,

@ -1405,7 +1405,7 @@ class Kernel(CodeGen):
],
values: Tuple[CSEVariable, ...],
) -> Tuple[CSEVariable, ...]:
raise NotImplementedError()
raise NotImplementedError

def bucketize(
self,

@ -1418,11 +1418,11 @@ class Kernel(CodeGen):
"""
See [Note: Inductor bucketize op]
"""
raise NotImplementedError()
raise NotImplementedError

@property
def assert_function(self) -> str:
raise NotImplementedError()
raise NotImplementedError

def indirect_assert(self, var, lower, upper, mask=None):
if lower and upper:

@ -1444,7 +1444,7 @@ class Kernel(CodeGen):
return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'

def index_to_str(self, index: sympy.Expr) -> str:
raise NotImplementedError()
raise NotImplementedError

def __enter__(self):
# TODO: hoist this to top level

@ -1737,4 +1737,4 @@ class KernelTemplate:
Generates a ChoiceCaller instance from the given arguments.
"""
raise NotImplementedError()
raise NotImplementedError

@ -2659,7 +2659,7 @@ class CppVecKernel(CppKernel):
mean, m2, weight = reduction_project(reduction_type, next_value)
return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})"
else:
raise NotImplementedError()
raise NotImplementedError

def indirect_assert(self, var, lower, upper, mask=None):
assert not mask, "do not support mask in indirect_indexing assertion"

@ -197,7 +197,7 @@ class CutlassEVTEpilogueTypeFormatter:
return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::maximum, ElementAcc, ElementAcc, RoundStyle>,{a}, {const_zero}>" # noqa: B950

def reduction(self, dtype, src_dtype, reduction_type, value):
raise CUTLASSEVTOpNotImplementedError()
raise CUTLASSEVTOpNotImplementedError

# Add more ops here...
def getvalue(self, result) -> str:

@ -354,7 +354,7 @@ class CutlassEVTEpilogueArgumentFormatter:
return a

def reduction(self, dtype, src_dtype, reduction_type, value):
raise CUTLASSEVTOpNotImplementedError()
raise CUTLASSEVTOpNotImplementedError

def getvalue(self, result) -> str:
return "{" + str(result) + "}"

@ -134,15 +134,15 @@ class AllocationTreeNode:

def get_live_ranges(self) -> LiveRanges:
"""Aggregate LiveRanges for all objects below this in tree"""
raise NotImplementedError()
raise NotImplementedError

def get_size_hint(self) -> int:
"""Number of bytes used for example inputs"""
raise NotImplementedError()
raise NotImplementedError

def get_symbolic_size(self) -> sympy.Expr:
"""Number of bytes needed at runtime"""
raise NotImplementedError()
raise NotImplementedError

def finalize(self, pool, offset) -> AllocationTreeNode:
"""Called after all allocations have been made"""

@ -1468,7 +1468,7 @@ class TritonKernel(Kernel):
def add_range(i, expr):
expr = sv.simplify(expr)
if not sv.statically_known_multiple_of(remaining[i], expr):
raise CantSplit()
raise CantSplit
# guard on the last item out
remaining[i] = FloorDiv(remaining[i], expr)
new_ranges[i].append(expr)

@ -1501,7 +1501,7 @@ class TritonKernel(Kernel):
if not sv.statically_known_multiple_of(
size, remaining[current_group]
):
raise CantSplit()
raise CantSplit
size1 = remaining[current_group]
size2 = FloorDiv(size, remaining[current_group])
return_getters.append(

@ -1346,7 +1346,7 @@ class WrapperCodeGen(CodeGen):
self.lines.append(LineContext(ctx))

def val_to_cpp_arg_str(self, type_, val) -> str:
raise NotImplementedError()
raise NotImplementedError

def val_to_arg_str(self, s):
from torch.utils._triton import dtype_to_string, has_triton_package

@ -954,10 +954,10 @@ class GraphLowering(torch.fx.Interpreter):
return self.add_tensor_constant(value, target)

def call_module(self, target, args, kwargs):
raise AssertionError()
raise AssertionError

def call_method(self, target, args, kwargs):
raise AssertionError()
raise AssertionError

def output(self, target, args, kwargs):
result = super().output(target, args, kwargs)

@ -2252,7 +2252,7 @@ class View(GenericView):
size_old = size_old * modulus
V.graph.sizevars.guard_equals(size_new, size_old)
else:
raise AssertionError()
raise AssertionError

while stack_old:
size_old = stack_old.pop()

@ -2818,7 +2818,7 @@ class FlexibleLayout(Layout):
"stride_ordered_for_memory_format, unsuppored memory_format: %s",
memory_format,
)
raise NotImplementedError()
raise NotImplementedError

@staticmethod
def same_ordered(sizes, stride):

@ -3666,16 +3666,16 @@ class ChoiceCaller:
return do_bench(lambda: algo(*args, out=out))

def call_name(self) -> str:
raise NotImplementedError()
raise NotImplementedError

def to_callable(self):
raise NotImplementedError()
raise NotImplementedError

def hash_key(self) -> str:
raise NotImplementedError()
raise NotImplementedError

def output_node(self) -> "TensorBox":
raise NotImplementedError()
raise NotImplementedError

def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
"""Information returned here is logged to the autotune log file when that is enabled."""

@ -3684,7 +3684,7 @@ class ChoiceCaller:

class TritonTemplateCallerBase(ChoiceCaller):
def get_make_kernel_render(self) -> Any:
raise NotImplementedError()
raise NotImplementedError

class MultiTemplateBuffer(TritonTemplateBuffer):

@ -4033,7 +4033,7 @@ class ExternKernel(InputsKernel):
wrapper.writeline(origin_str)

def codegen(self, wrapper):
raise NotImplementedError()
raise NotImplementedError

def get_kernel_name(self):
return self.cpp_kernel_name if V.graph.cpp_wrapper else self.python_kernel_name

@ -4157,7 +4157,7 @@ class ExternKernel(InputsKernel):
offset,
index,
)
raise NotImplementedError()
raise NotImplementedError

return ReinterpretView(
data=x.data,

@ -1629,7 +1629,7 @@ def bernoulli_p(x, *args):
# This shouldn't be called in general
@register_lowering(aten._foobar)
def _foobar(_):
raise AssertionError()
raise AssertionError

@functools.lru_cache(1)

@ -218,7 +218,7 @@ class PatternExpr:
def _match(
self, node: torch.fx.Node, ctx: MatchContext
) -> Union[Match, FailedMatch]:
raise NotImplementedError()
raise NotImplementedError

def match(self, node: torch.fx.Node) -> Union[Match, FailedMatch]:
try:

@ -361,7 +361,7 @@ class _TargetExpr(PatternExpr):
return isinstance(self.users, Multiple) or self.users > 1

def find_anchor_nodes(self, ctx: MatchContext, searched):
raise NotImplementedError()
raise NotImplementedError

def _match_fns(self, node: torch.fx.Node):
return (

@ -803,7 +803,7 @@ class PatternEntry:
extra_check: Callable[[Match], bool]

def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node):
raise NotImplementedError()
raise NotImplementedError

def register(self, pass_dicts, target=None, prepend=False):
if target is None:

@ -1507,7 +1507,7 @@ class PatternMatcherPass:

def _not_implemented(*args, **kwargs) -> NoReturn:
raise NotImplementedError()
raise NotImplementedError

def fx_to_pattern(

@ -2514,13 +2514,13 @@ class BaseScheduling:
"""
Check whether node1 and node2 can be vertically fused or not.
"""
raise NotImplementedError()
raise NotImplementedError

def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
"""
Check whether node1 and node2 can be horizontally fused or not.
"""
raise NotImplementedError()
raise NotImplementedError

def fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
"""

@ -2535,7 +2535,7 @@ class BaseScheduling:
"""
Process the iteration sizes in case a transformation needs to be applied.
"""
raise NotImplementedError()
raise NotImplementedError

def codegen_template(
self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode]

@ -2546,19 +2546,19 @@ class BaseScheduling:
This function is only available for triton now. If the third-party backend behaves as a sub-class
of TritonScheduling, it can override it or reuse it.
"""
raise NotImplementedError()
raise NotImplementedError

def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]):
"""
Generate a kernel given a list of pre-fused nodes.
"""
raise NotImplementedError()
raise NotImplementedError

def codegen_sync(self):
"""
Generate synchronization code for the kernel. This method depends on the hardware characteristics.
"""
raise NotImplementedError()
raise NotImplementedError

def ready_to_flush(self) -> bool:
"""

@ -2571,14 +2571,14 @@ class BaseScheduling:
"""
Flush the generated kernel and python wrapper code to the source code file.
"""
raise NotImplementedError()
raise NotImplementedError

def benchmark_fused_nodes(self, nodes):
"""
Benchmark fused list of nodes and return the execution time
in milliseconds on randomly generated inputs.
"""
raise NotImplementedError()
raise NotImplementedError

def get_fusion_pair_priority(self, node1, node2) -> int:
"""

@ -949,11 +949,11 @@ class DeferredLineBase:

def __call__(self) -> Optional[str]:
"""Returns either self.line or None to indicate the line has been 'unwritten'"""
raise NotImplementedError()
raise NotImplementedError

def _new_line(self, line: str) -> DeferredLineBase:
"""Returns a new deferred line with the same condition"""
raise NotImplementedError()
raise NotImplementedError

def with_prefix(self, prefix):
return self._new_line(f"{prefix}{self.line}")

@ -217,7 +217,7 @@ def _split_helper_int(tensor, indices_or_sections, axis, strict=False):
l, n = tensor.shape[axis], indices_or_sections

if n <= 0:
raise ValueError()
raise ValueError

if l % n == 0:
num, sz = n, l // n

@ -512,7 +512,7 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N
if like is not None:
raise NotImplementedError("'like' parameter is not supported.")
if order != "K":
raise NotImplementedError()
raise NotImplementedError

# a happy path
if (

@ -174,7 +174,7 @@ def maybe_copy_to(out, result, promote_scalar_result=False):
maybe_copy_to(o, r, promote_scalar_result) for o, r in zip(out, result)
)
else:
raise AssertionError() # We should never hit this path
raise AssertionError # We should never hit this path

def wrap_tensors(result):

@ -85,7 +85,7 @@ class OperatorBase:
self.functorch_table = {}

def __call__(self, *args, **kwargs):
raise NotImplementedError()
raise NotImplementedError

def has_kernel_for_dispatch_key(self, k):
return k in self.py_kernels

@ -165,7 +165,7 @@ class OperatorBase:
return fn

def name(self):
raise NotImplementedError()
raise NotImplementedError

is_included_in_alias = torch._C._dispatch_is_included_in_alias

@ -6,27 +6,27 @@ class _StreamBase(ABC):

@abstractmethod
def wait_event(self, event):
raise NotImplementedError()
raise NotImplementedError

@abstractmethod
def wait_stream(self, stream):
raise NotImplementedError()
raise NotImplementedError

@abstractmethod
def record_event(self, event=None):
raise NotImplementedError()
raise NotImplementedError

@abstractmethod
def query(self):
raise NotImplementedError()
raise NotImplementedError

@abstractmethod
def synchronize(self):
raise NotImplementedError()
raise NotImplementedError

@abstractmethod
def __eq__(self, stream):
raise NotImplementedError()
raise NotImplementedError

class _EventBase(ABC):

@ -34,12 +34,12 @@ class _EventBase(ABC):

@abstractmethod
def wait(self, stream=None):
raise NotImplementedError()
raise NotImplementedError

@abstractmethod
def query(self):
raise NotImplementedError()
raise NotImplementedError

@abstractmethod
def synchronize(self):
raise NotImplementedError()
raise NotImplementedError

@ -354,7 +354,7 @@ def trace(data):
elif e['action'] == 'oom':
size = e['size']
free = e['device_free']
out.write(f'raise OutOfMemoryError() # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n')
out.write(f'raise OutOfMemoryError # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n')
else:
out.write(f'{e}\n')
out.write(f"TOTAL MEM: {Bytes(count)}")

@ -465,7 +465,7 @@ class FSDPParam:
return [_to_dtype_if_needed(sharded_param_data, self.param_dtype)]
elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD:
if hasattr(self._sharded_local_tensor, "fsdp_pre_all_gather"):
raise NotImplementedError()
raise NotImplementedError
all_gather_input = _to_dtype_if_needed(
cast(torch.Tensor, self._sharded_post_forward_param_data),
self.param_dtype,

@ -39,7 +39,7 @@ class ParallelMode(ABC):
TODO(@wanchaol): some of these arguments are not necessary for
partitioning, remove the unnecessary ones later.
"""
raise NotImplementedError()
raise NotImplementedError

@abstractmethod
def transform_and_compile(self, gm: GraphModule) -> GraphModule:

@ -51,7 +51,7 @@ class ParallelMode(ABC):
the distributed environment.
"""
# TODO: add more necessary arguments to this interface.
raise NotImplementedError()
raise NotImplementedError

class DataParallel(ParallelMode):

@ -146,7 +146,7 @@ def _iterate_state_dict(
not isinstance(companion_obj, (list, tuple))
or len(companion_obj) != len(iter_object)
):
raise CompanionMismatch()
raise CompanionMismatch

ret = [
_iterate_state_dict(

@ -437,7 +437,7 @@ def _check_state_dict_similarity(
companion_obj: Any,
) -> torch.Tensor:
if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size():
raise CompanionMismatch()
raise CompanionMismatch
return obj

try:

@ -439,7 +439,7 @@ class ElasticAgent(abc.ABC):
Raises:
Exception - any other failures NOT related to worker process
"""
raise NotImplementedError()
raise NotImplementedError

@abc.abstractmethod
def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup:

@ -450,7 +450,7 @@ class ElasticAgent(abc.ABC):
Implementors are encouraged (but not required) to return
a defensive read-only copy.
"""
raise NotImplementedError()
raise NotImplementedError

class SimpleElasticAgent(ElasticAgent):

@ -477,7 +477,7 @@ class SimpleElasticAgent(ElasticAgent):
This is according to worker spec for the worker group .
Returns a map of ``local_rank`` to worker ``id``.
"""
raise NotImplementedError()
raise NotImplementedError

@abc.abstractmethod
def _stop_workers(self, worker_group: WorkerGroup) -> None:

@ -487,7 +487,7 @@ class SimpleElasticAgent(ElasticAgent):
``WorkerState``. That is, it must gracefully handle stopping
non-existent workers, unhealthy (stuck) workers, etc.
"""
raise NotImplementedError()
raise NotImplementedError

@abc.abstractmethod
def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:

@ -495,7 +495,7 @@ class SimpleElasticAgent(ElasticAgent):

This function also returns the new state of the worker group.
"""
raise NotImplementedError()
raise NotImplementedError

@abc.abstractmethod
def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:

@ -504,7 +504,7 @@ class SimpleElasticAgent(ElasticAgent):
Args:
death_sig: Signal to send to the child process, SIGTERM is default
"""
raise NotImplementedError()
raise NotImplementedError

@staticmethod
def _set_master_addr_port(

@ -458,7 +458,7 @@ class PContext(abc.ABC):
@abc.abstractmethod
def _start(self) -> None:
"""Start processes using strategy defined in a particular context."""
raise NotImplementedError()
raise NotImplementedError

@abc.abstractmethod
def _poll(self) -> Optional[RunProcsResult]:

@ -469,7 +469,7 @@ class PContext(abc.ABC):
successfully or any process fails. Returns ``None`` if
all processes are still running.
"""
raise NotImplementedError()
raise NotImplementedError

def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
"""

@ -514,7 +514,7 @@ class PContext(abc.ABC):
@abc.abstractmethod
def pids(self) -> Dict[int, int]:
"""Return pids of processes mapped by their respective local_ranks."""
raise NotImplementedError()
raise NotImplementedError

@abc.abstractmethod
def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:

@ -522,7 +522,7 @@ class PContext(abc.ABC):
Terminates all processes managed by this context and cleans up any
meta resources (e.g. redirect, error_file files).
"""
raise NotImplementedError()
raise NotImplementedError

def close(
self, death_sig: Optional[signal.Signals] = None, timeout: int = 30

@ -655,10 +655,10 @@ class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor):
continue

if action == _Action.ERROR_CLOSED:
raise RendezvousClosedError()
raise RendezvousClosedError

if action == _Action.ERROR_TIMEOUT:
raise RendezvousTimeoutError()
raise RendezvousTimeoutError

if action == _Action.SYNC:
# Delay the execution by one second to avoid overloading the

@ -270,7 +270,7 @@ class EtcdRendezvous:
self._rendezvous_deadline = time.time() + self._timeout
while True:
if time.time() > self._rendezvous_deadline:
raise RendezvousTimeoutError()
raise RendezvousTimeoutError

logger.info("Attempting to join next rendezvous")
try:

@ -340,17 +340,17 @@ class EtcdRendezvous:
logger.info("Observed existing rendezvous state: %s", state)

if state["status"] == "closed":
raise RendezvousClosedError()
raise RendezvousClosedError

if state["status"] == "joinable":
return self.join_phase(state["version"])

if state["status"] == "final":
self.handle_existing_rendezvous(state["version"])
raise EtcdRendezvousRetryImmediately()
raise EtcdRendezvousRetryImmediately

self.try_wait_for_state_change(etcd_index=active_version.etcd_index + 1)
raise EtcdRendezvousRetryableFailure()
raise EtcdRendezvousRetryableFailure

def join_phase(self, expected_version):
"""

@ -632,7 +632,7 @@ class EtcdRendezvous:
active_version, state = self.get_rdzv_state()

if state["status"] != "final" or state["version"] != expected_version:
raise EtcdRendezvousRetryImmediately()
raise EtcdRendezvousRetryImmediately

# Increment counter to signal an additional waiting worker.
state["num_workers_waiting"] += 1

@ -714,7 +714,7 @@ class EtcdRendezvous:
pass

if time.time() > self._rendezvous_deadline:
raise RendezvousTimeoutError()
raise RendezvousTimeoutError
active_version, state = self.get_rdzv_state()

def handle_join_last_call(self, expected_version, deadline):

@ -832,7 +832,7 @@ class EtcdRendezvous:
pass

if time.time() > self._rendezvous_deadline:
raise RendezvousTimeoutError()
raise RendezvousTimeoutError

# Unfortunately, we have to do another fetch in order to get last etcd_index.
return self.get_rdzv_state()

@ -963,7 +963,7 @@ class _PatchedFn(NamedTuple):
orig_fn: Any

def revert(self):
raise NotImplementedError()
raise NotImplementedError

class _PatchedFnSetItem(_PatchedFn):

@ -22,7 +22,7 @@ SYM_FUNCTION_MODE: Optional["SymDispatchMode"] = None
#
class SymDispatchMode:
def __sym_dispatch__(self, func, types, args, kwargs):
raise NotImplementedError()
raise NotImplementedError

def __enter__(self):
global SYM_FUNCTION_MODE

@ -31,7 +31,7 @@ class OperatorSupportBase(abc.ABC):
def is_node_supported(
self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
raise NotImplementedError()
raise NotImplementedError

@compatibility(is_backward_compatible=False)

@ -476,7 +476,7 @@ if _enabled:
# RecursiveScriptClass.
def forward_magic_method(self, method_name, *args, **kwargs):
if not self._c._has_method(method_name):
raise TypeError()
raise TypeError

self_method = self.__getattr__(method_name)
return self_method(*args, **kwargs)

@ -865,7 +865,7 @@ if _enabled:
if getattr(self_method, "__func__", None) == getattr(
RecursiveScriptModule, method_name
):
raise NotImplementedError()
raise NotImplementedError
return self_method(*args, **kwargs)

def __iter__(self):

@ -1184,7 +1184,7 @@ class _LazyConvXdMixin(LazyModuleMixin):
# Function to return the number of spatial dims expected for inputs to the module.
# This is expected to be implemented by subclasses.
def _get_num_spatial_dims(self) -> int:
raise NotImplementedError()
raise NotImplementedError

# LazyConv1d defines weight as a Tensor but derived class defines it as UnitializeParameter

@ -264,7 +264,7 @@ class Invocation:
# TODO: Implement this.
# Tracks top level call arguments and diagnostic options.
def __init__(self) -> None:
raise NotImplementedError()
raise NotImplementedError

@dataclasses.dataclass

@ -1910,7 +1910,7 @@ class TorchFunctionMode:
pass

def __torch_function__(self, func, types, args=(), kwargs=None):
raise NotImplementedError()
raise NotImplementedError

def __enter__(self):
_push_mode(self)

@ -36,7 +36,7 @@ class ErrorMeta(Exception):
super().__init__(
"If you are a user and see this message during normal operation "
"please file an issue at https://github.com/pytorch/pytorch/issues. "
"If you are a developer and working on the comparison functions, please `raise ErrorMeta().to_error()` "
"If you are a developer and working on the comparison functions, please `raise ErrorMeta.to_error()` "
"for user facing errors."
)
self.type = type

@ -336,7 +336,7 @@ class Pair(abc.ABC):

@staticmethod
def _inputs_not_supported() -> NoReturn:
raise UnsupportedInputs()
raise UnsupportedInputs

@staticmethod
def _check_inputs_isinstance(*inputs: Any, cls: Union[Type, Tuple[Type, ...]]):

@ -1217,7 +1217,7 @@ def not_close_error_metas(
)
except ErrorMeta as error_meta:
# Explicitly raising from None to hide the internal traceback
raise error_meta.to_error() from None
raise error_meta.to_error() from None # noqa: RSE102

error_metas: List[ErrorMeta] = []
for pair in pairs:

@ -9390,7 +9390,7 @@ class DistributedTest:

@staticmethod
def backward(ctx, grad_output):
raise RuntimeError()
raise RuntimeError

class MyModel(nn.Module):
def __init__(self, device):

@ -9534,7 +9534,7 @@ class DistributedTest:

@staticmethod
def backward(ctx, grad_output):
raise RuntimeError()
raise RuntimeError

class MyModel(torch.nn.Module):
def __init__(self, device):

@ -67,7 +67,7 @@ class TorchDispatchMode:
self.old_dispatch_mode_flag = False

def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise NotImplementedError()
raise NotImplementedError

def __enter__(self):
global _is_in_torch_dispatch_mode

@ -1085,7 +1085,7 @@ class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
if target_frame.early_stop and target_frame.recomp_counter[gid] == len(
target_frame.weak_holders
):
raise _StopRecomputationError()
raise _StopRecomputationError
# See Rule 6: [ retain_graph is True ] above
return x.detach()

@ -12,7 +12,7 @@ class _BaseDatasetFetcher:
self.drop_last = drop_last

def fetch(self, possibly_batched_index):
raise NotImplementedError()
raise NotImplementedError

class _IterableDatasetFetcher(_BaseDatasetFetcher):

@ -92,7 +92,7 @@ class Capture:
if attrname == 'kwarg' or attrname == 'kwargs':
raise Exception('no kwargs!')
if attrname in ['__deepcopy__']:
raise AttributeError()
raise AttributeError
result = CaptureGetAttr(self, attrname, ctx=self.ctx)
return result

@ -79,7 +79,7 @@ class Sampler(Generic[T_co]):
# Calling `len(subclass_instance)` raises:
# TypeError: 'NotImplementedType' object cannot be interpreted as an integer
#
# + `raise NotImplementedError()`:
# + `raise NotImplementedError`:
# This prevents triggering some fallback behavior. E.g., the built-in
# `list(X)` tries to call `len(X)` first, and executes a different code
# path if the method is not found or `NotImplemented` is returned, while

@ -507,7 +507,7 @@ def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
elif k is UfuncKey.CPUVector:
compute_t = VectorizedCType(BaseCType(scalar_t))
else:
raise AssertionError()
raise AssertionError
inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
if k not in inner_ufunc_sigs:
inner_ufunc_sigs[k] = UfuncSignature(

@ -1151,7 +1151,7 @@ def compute_cpp_argument_yaml(
arg["default"] = cpp_a.default
return arg
elif isinstance(cpp_a.argument, SelfArgument):
raise AssertionError()
raise AssertionError
elif isinstance(cpp_a.argument, Argument):
return compute_argument_yaml(
cpp_a.argument,
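
The hunks above all make the same mechanical change: an exception class raised with no arguments loses its redundant parentheses, while genuine call expressions that construct the exception (such as the to_error() line marked # noqa: RSE102) keep them. A minimal, self-contained sketch of that distinction, using hypothetical names rather than code from this commit:

# Illustrative only; PoolExhaustedError and make_error are hypothetical
# names for this sketch, not identifiers from the commit above.
class PoolExhaustedError(Exception):
    pass


def make_error() -> PoolExhaustedError:
    # A factory that returns an exception instance; the call after `raise`
    # does real work here and must not be removed.
    return PoolExhaustedError("no free workers")


def before() -> None:
    raise PoolExhaustedError()  # the empty parentheses add nothing


def after() -> None:
    raise PoolExhaustedError  # equivalent: Python instantiates the class when it is raised


def factory_case() -> None:
    raise make_error()  # noqa: RSE102 -- a genuine zero-argument call, mirroring the to_error() case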