mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153966 Approved by: https://github.com/Lucaskabela
142 lines
4.5 KiB
Diff
142 lines
4.5 KiB
Diff
diff --git a/test/dynamo/cpython/3_13/test_baseexception.py b/test/dynamo/cpython/3_13/test_baseexception.py
|
|
index e599b02c17d..057b6ec01b9 100644
|
|
--- a/test/dynamo/cpython/3_13/test_baseexception.py
|
|
+++ b/test/dynamo/cpython/3_13/test_baseexception.py
|
|
@@ -1,10 +1,64 @@
|
|
+# ======= BEGIN Dynamo patch =======
|
|
+# Owner(s): ["module: dynamo"]
|
|
+
|
|
+# ruff: noqa
|
|
+# flake8: noqa
|
|
+
|
|
+# Test copied from
|
|
+# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_baseexception.py
|
|
+
|
|
+import sys
|
|
+import torch
|
|
+import torch._dynamo.test_case
|
|
+import unittest
|
|
+from torch._dynamo.test_case import CPythonTestCase
|
|
+from torch.testing._internal.common_utils import run_tests
|
|
+
|
|
+__TestCase = CPythonTestCase
|
|
+
|
|
+
|
|
+# redirect import statements
|
|
+import sys
|
|
+import importlib.abc
|
|
+
|
|
+redirect_imports = (
|
|
+ "test.mapping_tests",
|
|
+ "test.typinganndata",
|
|
+ "test.test_grammar",
|
|
+ "test.test_math",
|
|
+ "test.test_iter",
|
|
+ "test.typinganndata.ann_module",
|
|
+)
|
|
+
|
|
+class RedirectImportFinder(importlib.abc.MetaPathFinder):
|
|
+ def find_spec(self, fullname, path, target=None):
|
|
+ # Check if the import is the problematic one
|
|
+ if fullname in redirect_imports:
|
|
+ try:
|
|
+ # Attempt to import the standalone module
|
|
+ name = fullname.removeprefix("test.")
|
|
+ r = importlib.import_module(name)
|
|
+ # Redirect the module in sys.modules
|
|
+ sys.modules[fullname] = r
|
|
+ # Return a module spec from the found module
|
|
+ return importlib.util.find_spec(name)
|
|
+ except ImportError:
|
|
+ return None
|
|
+ return None
|
|
+
|
|
+# Add the custom finder to sys.meta_path
|
|
+sys.meta_path.insert(0, RedirectImportFinder())
|
|
+
|
|
+
|
|
+# ======= END DYNAMO PATCH =======
|
|
+
|
|
import unittest
|
|
import builtins
|
|
import os
|
|
from platform import system as platform_system
|
|
|
|
|
|
-class ExceptionClassTests(unittest.TestCase):
|
|
+class ExceptionClassTests(__TestCase):
|
|
|
|
"""Tests for anything relating to exception objects themselves (e.g.,
|
|
inheritance hierarchy)"""
|
|
@@ -78,9 +132,6 @@ class ExceptionClassTests(unittest.TestCase):
|
|
last_depth = depth
|
|
finally:
|
|
inheritance_tree.close()
|
|
-
|
|
- # Underscore-prefixed (private) exceptions don't need to be documented
|
|
- exc_set = set(e for e in exc_set if not e.startswith('_'))
|
|
self.assertEqual(len(exc_set), 0, "%s not accounted for" % exc_set)
|
|
|
|
interface_tests = ("length", "args", "str", "repr")
|
|
@@ -122,12 +173,13 @@ class ExceptionClassTests(unittest.TestCase):
|
|
# in PyObject_SetAttr.
|
|
import gc
|
|
d = {}
|
|
- class HashThisKeyWillClearTheDict(str):
|
|
- def __hash__(self) -> int:
|
|
- d.clear()
|
|
- return super().__hash__()
|
|
- class Value(str):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class HashThisKeyWillClearTheDict(str):
|
|
+ def __hash__(self) -> int:
|
|
+ d.clear()
|
|
+ return super().__hash__()
|
|
+ class Value(str):
|
|
+ pass
|
|
exc = Exception()
|
|
|
|
d[HashThisKeyWillClearTheDict()] = Value() # refcount of Value() is 1 now
|
|
@@ -142,7 +194,7 @@ class ExceptionClassTests(unittest.TestCase):
|
|
gc.collect()
|
|
|
|
|
|
-class UsageTests(unittest.TestCase):
|
|
+class UsageTests(__TestCase):
|
|
|
|
"""Test usage of exceptions"""
|
|
|
|
@@ -182,8 +234,9 @@ class UsageTests(unittest.TestCase):
|
|
# BaseException; the ability was not possible until BaseException's
|
|
# introduction so no need to support new-style objects that do not
|
|
# inherit from it.
|
|
- class NewStyleClass(object):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class NewStyleClass(object):
|
|
+ pass
|
|
self.raise_fails(NewStyleClass)
|
|
self.raise_fails(NewStyleClass())
|
|
|
|
@@ -194,8 +247,9 @@ class UsageTests(unittest.TestCase):
|
|
def test_catch_non_BaseException(self):
|
|
# Trying to catch an object that does not inherit from BaseException
|
|
# is not allowed.
|
|
- class NonBaseException(object):
|
|
- pass
|
|
+ with torch._dynamo.error_on_graph_break(False):
|
|
+ class NonBaseException(object):
|
|
+ pass
|
|
self.catch_fails(NonBaseException)
|
|
self.catch_fails(NonBaseException())
|
|
|
|
@@ -208,5 +262,5 @@ class UsageTests(unittest.TestCase):
|
|
self.catch_fails("spam")
|
|
|
|
|
|
-if __name__ == '__main__':
|
|
- unittest.main()
|
|
+if __name__ == "__main__":
|
|
+ run_tests()
|