pytorch/torch/_dynamo/exc.py
Horace He 12dd877395 Fix all references to torchdynamo from the merge (#87731)
cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @jansel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87731
Approved by: https://github.com/yanboliang, https://github.com/ezyang, https://github.com/anijain2305, https://github.com/jansel
2022-10-31 06:51:07 +00:00

73 lines
1.8 KiB
Python

import os
import textwrap
from .utils import counters
class TorchDynamoException(RuntimeError):
    """Common base class for the exceptions defined in this module.

    Adds no behavior of its own on top of ``RuntimeError``; it exists so
    that callers can catch the module's exceptions with a single
    ``except TorchDynamoException`` clause.
    """
class InternalTorchDynamoError(TorchDynamoException):
    """Exception type that adds nothing beyond ``TorchDynamoException``.

    NOTE(review): the name suggests it marks failures inside TorchDynamo
    itself; it is not raised anywhere in this module, so confirm the
    intended meaning against its raise sites.
    """
class RestartAnalysis(TorchDynamoException):
    """Exception type that adds nothing beyond ``TorchDynamoException``.

    NOTE(review): the name suggests it is a control-flow signal asking the
    caller to restart its analysis; it is not raised in this module, so
    confirm against its raise/except sites.
    """
class SkipFrame(TorchDynamoException):
    """Exception type that adds nothing beyond ``TorchDynamoException``.

    NOTE(review): the name suggests it signals that the current frame
    should be skipped; it is not raised in this module, so confirm against
    its raise/except sites.
    """
class TorchRuntimeError(TorchDynamoException):
    """Exception type that adds nothing beyond ``TorchDynamoException``.

    NOTE(review): the name suggests it reports a runtime error coming from
    torch code; it is not raised in this module, so confirm against its
    raise sites.
    """
class ResetRequired(TorchDynamoException):
    """Error instructing the user to call ``torch._dynamo.reset()``.

    The message is fixed, so the constructor takes no arguments.
    """

    def __init__(self):
        # Zero-argument super() matches the style used by
        # BackendCompilerFailed in this file.
        # dedent() strips the common indentation but keeps the literal's
        # leading and trailing newline, so the message begins and ends
        # with a newline character.
        super().__init__(
            textwrap.dedent(
                """
                Must call `torch._dynamo.reset()` before changing backends. Detected two calls to
                `torch._dynamo.optimize(...)` with a different backend compiler arguments.
                """
            )
        )
class BackendCompilerFailed(TorchDynamoException):
    """Reports that ``backend_fn`` raised ``inner_exception``.

    Attributes set by the constructor:
        backend_name: ``backend_fn.__name__``, or ``"?"`` when the object
            has no ``__name__`` attribute.
        inner_exception: the exception object passed in, kept as-is.
    """

    def __init__(self, backend_fn, inner_exception):
        name = getattr(backend_fn, "__name__", "?")
        self.backend_name = name
        self.inner_exception = inner_exception
        # Message layout: "<backend> raised <ExceptionType>: <str(exception)>".
        exc_type_name = type(inner_exception).__name__
        super().__init__(f"{name} raised {exc_type_name}: {inner_exception}")
class Unsupported(TorchDynamoException):
    """Exception that tallies itself in the imported ``counters`` table.

    Constructing an instance increments ``counters["unimplemented"][msg]``.
    ``remove_from_stats`` followed by ``add_to_stats(other_category)`` moves
    the tally to a different category.

    Attributes:
        real_stack: starts as an empty list; not populated in this module.
        msg: the message passed to the constructor, used as the counter key.
        category: the ``counters`` category this instance is currently
            counted under.
    """

    def __init__(self, msg):
        # Zero-argument super() matches the style used by
        # BackendCompilerFailed in this file.
        super().__init__(msg)
        self.real_stack = []
        self.msg = msg
        self.category = None
        # Count the instance right away under the default category.
        self.add_to_stats()

    def remove_from_stats(self):
        """Decrement this message's tally in the current category.

        The entry is deleted once its count drops to zero or below, so
        the table does not keep zero-valued keys.
        """
        counters[self.category][self.msg] -= 1
        if counters[self.category][self.msg] <= 0:
            del counters[self.category][self.msg]

    def add_to_stats(self, category="unimplemented"):
        """Record ``category`` on the instance and increment its tally.

        Does not remove the tally from a previously set category; call
        ``remove_from_stats`` first when re-categorizing.
        """
        self.category = category
        counters[category][self.msg] += 1
def unimplemented(msg: str):
    """Raise ``Unsupported`` carrying ``msg``; never returns normally.

    Before raising, ``msg`` is compared with the ``BREAK`` environment
    variable. If they are equal the assert fails, so an ``AssertionError``
    propagates instead of ``Unsupported`` (presumably a debugging aid for
    stopping at one specific message; the check is skipped when asserts
    are disabled).
    """
    break_on = os.environ.get("BREAK", False)
    assert msg != break_on
    raise Unsupported(msg)
def warning(msg: str):
    """Increment the tally for ``msg`` under ``counters["warnings"]``.

    After counting, ``msg`` is compared with the ``BREAK`` environment
    variable and the assert fails when they are equal (presumably a
    debugging aid; the check is skipped when asserts are disabled).
    """
    warning_counts = counters["warnings"]
    warning_counts[msg] += 1
    break_on = os.environ.get("BREAK", False)
    assert msg != break_on