mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Fix python_type for UserDefinedClassExceptionVariable (#166251)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166251 Approved by: https://github.com/Lucaskabela
This commit is contained in:
parent
61bad3c1ea
commit
610c09f8f4
|
|
@ -889,20 +889,26 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
|||
assert z == 1
|
||||
|
||||
def test_user_defined_exception_variable(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
z = 0
|
||||
try:
|
||||
raise CustomException
|
||||
except ValueError:
|
||||
z = 1
|
||||
except CustomException:
|
||||
except CustomException as e:
|
||||
# trying to call python_type on the
|
||||
# UserDefinedExceptionClassVariable
|
||||
cls = type(e)
|
||||
if type(cls) is type:
|
||||
t = t + 1
|
||||
z = 2
|
||||
assert z == 2
|
||||
return t.sin()
|
||||
|
||||
t = torch.randn(2)
|
||||
fn(t)
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
self.assertEqual(fn(t), opt_fn(t))
|
||||
|
||||
def test_user_defined_exception_with_args(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
|
|
|
|||
|
|
@ -872,9 +872,6 @@ class UserDefinedExceptionClassVariable(UserDefinedClassVariable):
|
|||
def fn(self):
|
||||
return self.value
|
||||
|
||||
def python_type(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class NO_SUCH_SUBOBJ:
|
||||
pass
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user