[dynamo][exception] Support raise exception from None (#134028)

Fixes https://github.com/pytorch/pytorch/issues/132362

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134028
Approved by: https://github.com/yanboliang
This commit is contained in:
Animesh Jain 2024-08-20 13:47:12 -07:00 committed by PyTorch MergeBot
parent bd0db490bf
commit 0d79f67a25
2 changed files with 74 additions and 28 deletions

View File

@ -330,6 +330,39 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
res = opt_fn(x)
self.assertEqual(ref, res)
def test_raise_from_None(self):
# Inspired from os.environ
class MyMapping:
def __init__(self, d):
self._d = d
def __getitem__(self, key):
try:
value = self._d[key]
except KeyError:
raise KeyError(key) from None
return value
d = MyMapping({"a": 10, "b": 20})
def mapping_get(obj, key, value=None):
try:
return obj.__getitem__(key)
except KeyError:
return value
def fn(x, d, key):
x = torch.sin(x + 1)
return x, mapping_get(d, key)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.rand(2, 3)
ref = fn(x, d, "m")
res = opt_fn(x, d, "m")
self.assertEqual(ref[0], res[0])
self.assertEqual(ref[1], res[1])
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -1355,34 +1355,48 @@ class InstructionTranslatorBase(
self.push(ConstantVariable.create(None))
self.jump(inst)
def _raise_exception_variable(self, inst):
val = self.pop()
# User can raise exception in 2 ways
# 1) raise exception type - raise NotImplementedError
# 2) raise execption instance - raise NotImplemetedError("foo")
# 1) when user raises exception type
if isinstance(val, variables.BuiltinVariable):
# Create the instance of the exception type
# https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549
val = val.call_function(self, [], {}) # type: ignore[arg-type]
# Save the exception in a global data structure
self.exn_vt_stack.append(val)
# 2) when user raises exception instance
if isinstance(val, variables.ExceptionVariable):
if val.exc_type is StopIteration:
# StopIteration is used to find the end of iteration while tracing __next__
raise exc.ObservedUserStopIteration(f"raised exception {val}")
raise exc.ObservedException(f"raised exception {val}")
unimplemented(f"raise {exc}")
def RAISE_VARARGS(self, inst):
if inst.arg == 0:
unimplemented("re-raise")
elif inst.arg == 1:
val = self.pop()
# User can raise exception in 2 ways
# 1) raise exception type - raise NotImplementedError
# 2) raise execption instance - raise NotImplemetedError("foo")
# 1) when user raises exception type
if isinstance(val, variables.BuiltinVariable):
# Create the instance of the exception type
# https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549
val = val.call_function(self, [], {}) # type: ignore[arg-type]
# Save the exception in a global data structure
self.exn_vt_stack.append(val)
# 2) when user raises exception instance
if isinstance(val, variables.ExceptionVariable):
if val.exc_type is StopIteration:
# StopIteration is used to find the end of iteration while tracing __next__
raise exc.ObservedUserStopIteration(f"raised exception {val}")
raise exc.ObservedException(f"raised exception {val}")
unimplemented(f"raise {exc}")
self._raise_exception_variable(inst)
else:
# Support raise .. from None ... Dynamo does not track __cause__ and other attributes of exception. So we
# ignore `from None` part.
from_vt = self.pop()
if isinstance(from_vt, ConstantVariable) and from_vt.value is None:
self._raise_exception_variable(inst)
unimplemented("raise ... from ...")
def RERAISE(self, inst):
if sys.version_info >= (3, 11):
# RERAISE is currently supported in a narrow case of `raise ... from None`
self._raise_exception_variable(inst)
unimplemented("RERAISE")
def exception_handler(self, raised_exception):
if sys.version_info >= (3, 11):
exn_tab_entry = self.current_instruction.exn_tab_entry
@ -1396,10 +1410,6 @@ class InstructionTranslatorBase(
# 2) if 'lasti' is true, then push the offset that the exception was raised at
if exn_tab_entry.lasti:
# This is untested. Any test that tests this end-to-end
# requires supporting more bytecodes. Therefore graph
# breaking for now.
unimplemented("lasti=True while exception handling")
self.push(
variables.ConstantVariable(self.current_instruction.offset)
)
@ -1431,9 +1441,12 @@ class InstructionTranslatorBase(
# https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456
self.popn(3)
if len(self.block_stack) == 0:
unimplemented(
"exception is raised when block stack " "is empty"
)
# No handler found in this frame. Bubble the exception to the parent
# instruction translater.
self.stack.clear()
if type(self) is InstructionTranslator:
raise Unsupported("Observed exception")
raise raised_exception
block_stack_entry = self.block_stack.pop()
if block_stack_entry.inst.opname != "SETUP_FINALLY":