mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo][exception] Support raise exception from None (#134028)
Fixes https://github.com/pytorch/pytorch/issues/132362 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134028 Approved by: https://github.com/yanboliang
This commit is contained in:
parent
bd0db490bf
commit
0d79f67a25
|
|
@ -330,6 +330,39 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
res = opt_fn(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_raise_from_None(self):
|
||||
# Inspired from os.environ
|
||||
class MyMapping:
|
||||
def __init__(self, d):
|
||||
self._d = d
|
||||
|
||||
def __getitem__(self, key):
|
||||
try:
|
||||
value = self._d[key]
|
||||
except KeyError:
|
||||
raise KeyError(key) from None
|
||||
return value
|
||||
|
||||
d = MyMapping({"a": 10, "b": 20})
|
||||
|
||||
def mapping_get(obj, key, value=None):
|
||||
try:
|
||||
return obj.__getitem__(key)
|
||||
except KeyError:
|
||||
return value
|
||||
|
||||
def fn(x, d, key):
|
||||
x = torch.sin(x + 1)
|
||||
return x, mapping_get(d, key)
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
|
||||
x = torch.rand(2, 3)
|
||||
ref = fn(x, d, "m")
|
||||
res = opt_fn(x, d, "m")
|
||||
self.assertEqual(ref[0], res[0])
|
||||
self.assertEqual(ref[1], res[1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -1355,34 +1355,48 @@ class InstructionTranslatorBase(
|
|||
self.push(ConstantVariable.create(None))
|
||||
self.jump(inst)
|
||||
|
||||
def _raise_exception_variable(self, inst):
|
||||
val = self.pop()
|
||||
# User can raise exception in 2 ways
|
||||
# 1) raise exception type - raise NotImplementedError
|
||||
# 2) raise execption instance - raise NotImplemetedError("foo")
|
||||
|
||||
# 1) when user raises exception type
|
||||
if isinstance(val, variables.BuiltinVariable):
|
||||
# Create the instance of the exception type
|
||||
# https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549
|
||||
val = val.call_function(self, [], {}) # type: ignore[arg-type]
|
||||
|
||||
# Save the exception in a global data structure
|
||||
self.exn_vt_stack.append(val)
|
||||
|
||||
# 2) when user raises exception instance
|
||||
if isinstance(val, variables.ExceptionVariable):
|
||||
if val.exc_type is StopIteration:
|
||||
# StopIteration is used to find the end of iteration while tracing __next__
|
||||
raise exc.ObservedUserStopIteration(f"raised exception {val}")
|
||||
raise exc.ObservedException(f"raised exception {val}")
|
||||
unimplemented(f"raise {exc}")
|
||||
|
||||
def RAISE_VARARGS(self, inst):
|
||||
if inst.arg == 0:
|
||||
unimplemented("re-raise")
|
||||
elif inst.arg == 1:
|
||||
val = self.pop()
|
||||
# User can raise exception in 2 ways
|
||||
# 1) raise exception type - raise NotImplementedError
|
||||
# 2) raise execption instance - raise NotImplemetedError("foo")
|
||||
|
||||
# 1) when user raises exception type
|
||||
if isinstance(val, variables.BuiltinVariable):
|
||||
# Create the instance of the exception type
|
||||
# https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549
|
||||
val = val.call_function(self, [], {}) # type: ignore[arg-type]
|
||||
|
||||
# Save the exception in a global data structure
|
||||
self.exn_vt_stack.append(val)
|
||||
|
||||
# 2) when user raises exception instance
|
||||
if isinstance(val, variables.ExceptionVariable):
|
||||
if val.exc_type is StopIteration:
|
||||
# StopIteration is used to find the end of iteration while tracing __next__
|
||||
raise exc.ObservedUserStopIteration(f"raised exception {val}")
|
||||
raise exc.ObservedException(f"raised exception {val}")
|
||||
unimplemented(f"raise {exc}")
|
||||
self._raise_exception_variable(inst)
|
||||
else:
|
||||
# Support raise .. from None ... Dynamo does not track __cause__ and other attributes of exception. So we
|
||||
# ignore `from None` part.
|
||||
from_vt = self.pop()
|
||||
if isinstance(from_vt, ConstantVariable) and from_vt.value is None:
|
||||
self._raise_exception_variable(inst)
|
||||
unimplemented("raise ... from ...")
|
||||
|
||||
def RERAISE(self, inst):
|
||||
if sys.version_info >= (3, 11):
|
||||
# RERAISE is currently supported in a narrow case of `raise ... from None`
|
||||
self._raise_exception_variable(inst)
|
||||
unimplemented("RERAISE")
|
||||
|
||||
def exception_handler(self, raised_exception):
|
||||
if sys.version_info >= (3, 11):
|
||||
exn_tab_entry = self.current_instruction.exn_tab_entry
|
||||
|
|
@ -1396,10 +1410,6 @@ class InstructionTranslatorBase(
|
|||
|
||||
# 2) if 'lasti' is true, then push the offset that the exception was raised at
|
||||
if exn_tab_entry.lasti:
|
||||
# This is untested. Any test that tests this end-to-end
|
||||
# requires supporting more bytecodes. Therefore graph
|
||||
# breaking for now.
|
||||
unimplemented("lasti=True while exception handling")
|
||||
self.push(
|
||||
variables.ConstantVariable(self.current_instruction.offset)
|
||||
)
|
||||
|
|
@ -1431,9 +1441,12 @@ class InstructionTranslatorBase(
|
|||
# https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456
|
||||
self.popn(3)
|
||||
if len(self.block_stack) == 0:
|
||||
unimplemented(
|
||||
"exception is raised when block stack " "is empty"
|
||||
)
|
||||
# No handler found in this frame. Bubble the exception to the parent
|
||||
# instruction translater.
|
||||
self.stack.clear()
|
||||
if type(self) is InstructionTranslator:
|
||||
raise Unsupported("Observed exception")
|
||||
raise raised_exception
|
||||
block_stack_entry = self.block_stack.pop()
|
||||
|
||||
if block_stack_entry.inst.opname != "SETUP_FINALLY":
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user