mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @jansel Pull Request resolved: https://github.com/pytorch/pytorch/pull/87731 Approved by: https://github.com/yanboliang, https://github.com/ezyang, https://github.com/anijain2305, https://github.com/jansel
73 lines
1.8 KiB
Python
import os
|
|
import textwrap
|
|
|
|
from .utils import counters
|
|
|
|
|
|
class TorchDynamoException(RuntimeError):
    """Common base class for every exception type declared in this module.

    It derives from ``RuntimeError``, so a handler that catches
    ``RuntimeError`` also catches any TorchDynamo-specific exception.
    """
|
|
|
|
|
|
class InternalTorchDynamoError(TorchDynamoException):
    """Marker subclass of ``TorchDynamoException`` with no extra behavior.

    NOTE(review): the name suggests it reports a failure inside TorchDynamo
    itself rather than in user code. This module defines it but never raises
    it; confirm the intended use at the raising sites.
    """
|
|
|
|
|
|
class RestartAnalysis(TorchDynamoException):
    """Marker subclass of ``TorchDynamoException`` with no extra behavior.

    NOTE(review): the name suggests it is used as a control-flow signal
    asking the caller to redo its analysis. This module defines it but never
    raises or catches it; confirm against the raising/catching sites.
    """
|
|
|
|
|
|
class SkipFrame(TorchDynamoException):
    """Marker subclass of ``TorchDynamoException`` with no extra behavior.

    NOTE(review): the name suggests it tells the caller to skip processing
    the current frame. This module defines it but never raises it; confirm
    against the raising sites.
    """
|
|
|
|
|
|
class TorchRuntimeError(TorchDynamoException):
    """Marker subclass of ``TorchDynamoException`` with no extra behavior.

    NOTE(review): this module defines it but never raises it. Presumably it
    wraps a runtime failure from torch itself; confirm at the raising sites.
    """
|
|
|
|
|
|
class ResetRequired(TorchDynamoException):
    """Tells the user to call ``torch._dynamo.reset()`` before changing backends.

    Takes no constructor arguments; the exception message is fixed.
    """

    def __init__(self):
        # dedent() strips the literal's common indentation (and blanks the
        # whitespace-only closing line), so the message text does not depend
        # on how deeply this code is indented.
        message = textwrap.dedent(
            """
            Must call `torch._dynamo.reset()` before changing backends. Detected two calls to
            `torch._dynamo.optimize(...)` with a different backend compiler arguments.
            """
        )
        super().__init__(message)
|
|
|
|
|
|
class BackendCompilerFailed(TorchDynamoException):
    """Reports that a backend compiler callable raised an exception.

    Attributes:
        backend_name: ``backend_fn.__name__``, or ``"?"`` when the callable
            has no ``__name__`` attribute (e.g. some callable objects).
        inner_exception: the exception object that was passed in, kept so
            callers can inspect it.

    The message has the form ``"<backend_name> raised <ExcType>: <exc>"``.
    """

    def __init__(self, backend_fn, inner_exception):
        self.backend_name = getattr(backend_fn, "__name__", "?")
        self.inner_exception = inner_exception
        inner_type_name = type(inner_exception).__name__
        super().__init__(
            f"{self.backend_name} raised {inner_type_name}: {inner_exception}"
        )
|
|
|
|
|
|
class Unsupported(TorchDynamoException):
    """Exception raised by ``unimplemented()``; every instance is tallied in ``counters``.

    Constructing an instance immediately counts its message under the
    ``"unimplemented"`` category of ``counters``. Callers can move that count
    to another category with ``remove_from_stats()`` followed by
    ``add_to_stats(category)``.

    Attributes:
        msg: the message passed to the constructor.
        real_stack: initialised to an empty list; nothing in this class fills
            it in (presumably populated by whoever catches the exception).
        category: the ``counters`` key this instance is currently counted under.
    """

    def __init__(self, msg):
        super().__init__(msg)
        self.real_stack = []
        self.msg = msg
        self.category = None
        self.add_to_stats()

    def remove_from_stats(self):
        """Undo this instance's count under ``self.category``.

        The message entry is deleted once its count drops to zero or below,
        so no zero or negative entries are left behind.
        """
        per_message = counters[self.category]
        per_message[self.msg] -= 1
        if per_message[self.msg] <= 0:
            del per_message[self.msg]

    def add_to_stats(self, category="unimplemented"):
        """Record *category* on the instance and count ``self.msg`` under it.

        This does not remove any earlier count. Call ``remove_from_stats()``
        first when re-categorising, or the message is counted twice.
        """
        self.category = category
        counters[category][self.msg] += 1
|
|
|
|
|
|
def unimplemented(msg: str):
    """Raise :class:`Unsupported` carrying *msg*.

    Debugging hook: when the ``BREAK`` environment variable equals *msg*, an
    ``AssertionError`` is raised instead. This only happens when assertions
    are enabled (not under ``python -O``), which lets a developer stop at one
    specific message.
    """
    break_on = os.environ.get("BREAK", False)
    assert msg != break_on
    raise Unsupported(msg)
|
|
|
|
|
|
def warning(msg: str):
    """Count one occurrence of *msg* in ``counters["warnings"]``.

    Supports the same ``BREAK`` debugging hook as ``unimplemented``: if the
    ``BREAK`` environment variable equals *msg*, an ``AssertionError`` is
    raised when assertions are enabled. The count is recorded before that
    check, so a triggered break still leaves the warning counted.
    """
    warning_counts = counters["warnings"]
    warning_counts[msg] += 1
    assert msg != os.environ.get("BREAK", False)
|