mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "Fixes for CPython int/float tests (#155978)"
This reverts commit 216bd6091e.
Reverted https://github.com/pytorch/pytorch/pull/155978 on behalf of https://github.com/huydhn due to Some tests are still failing in trunk ([comment](https://github.com/pytorch/pytorch/pull/155978#issuecomment-3014185210))
This commit is contained in:
parent
7c51619e7f
commit
0decd966af
|
|
@ -1,5 +1,5 @@
|
||||||
diff --git a/test/dynamo/cpython/3_13/test_int.py b/test/dynamo/cpython/3_13/test_int.py
|
diff --git a/test/dynamo/cpython/3_13/test_int.py b/test/dynamo/cpython/3_13/test_int.py
|
||||||
index 48825f46911..4ab200372ea 100644
|
index 48825f46911..ac7aeacbc01 100644
|
||||||
--- a/test/dynamo/cpython/3_13/test_int.py
|
--- a/test/dynamo/cpython/3_13/test_int.py
|
||||||
+++ b/test/dynamo/cpython/3_13/test_int.py
|
+++ b/test/dynamo/cpython/3_13/test_int.py
|
||||||
@@ -1,13 +1,137 @@
|
@@ -1,13 +1,137 @@
|
||||||
|
|
@ -153,15 +153,7 @@ index 48825f46911..4ab200372ea 100644
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
self.assertEqual(int(314), 314)
|
self.assertEqual(int(314), 314)
|
||||||
@@ -566,6 +690,7 @@ class IntTestCases(unittest.TestCase):
|
@@ -607,7 +731,7 @@ class IntTestCases(unittest.TestCase):
|
||||||
self.assertEqual(n, 1)
|
|
||||||
self.assertIs(type(n), IntSubclass)
|
|
||||||
|
|
||||||
+ @skipIfTorchDynamo("flaky under dynamo")
|
|
||||||
def test_error_message(self):
|
|
||||||
def check(s, base=None):
|
|
||||||
with self.assertRaises(ValueError,
|
|
||||||
@@ -607,7 +732,7 @@ class IntTestCases(unittest.TestCase):
|
|
||||||
self.assertEqual(int('1_2_3_4_5_6_7', 32), 1144132807)
|
self.assertEqual(int('1_2_3_4_5_6_7', 32), 1144132807)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -170,7 +162,7 @@ index 48825f46911..4ab200372ea 100644
|
||||||
|
|
||||||
int_class = int # Override this in subclasses to reuse the suite.
|
int_class = int # Override this in subclasses to reuse the suite.
|
||||||
|
|
||||||
@@ -818,7 +943,7 @@ class IntSubclassStrDigitLimitsTests(IntStrDigitLimitsTests):
|
@@ -818,7 +942,7 @@ class IntSubclassStrDigitLimitsTests(IntStrDigitLimitsTests):
|
||||||
int_class = IntSubclass
|
int_class = IntSubclass
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -179,7 +171,7 @@ index 48825f46911..4ab200372ea 100644
|
||||||
# Tests of the functions in _pylong.py. Those get used when the
|
# Tests of the functions in _pylong.py. Those get used when the
|
||||||
# number of digits in the input values are large enough.
|
# number of digits in the input values are large enough.
|
||||||
|
|
||||||
@@ -922,4 +1047,4 @@ class PyLongModuleTests(unittest.TestCase):
|
@@ -922,4 +1046,4 @@ class PyLongModuleTests(unittest.TestCase):
|
||||||
bits <<= 1
|
bits <<= 1
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ import torch
|
||||||
import torch._dynamo.test_case
|
import torch._dynamo.test_case
|
||||||
import unittest
|
import unittest
|
||||||
from torch._dynamo.test_case import CPythonTestCase
|
from torch._dynamo.test_case import CPythonTestCase
|
||||||
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
|
from torch.testing._internal.common_utils import run_tests
|
||||||
|
|
||||||
__TestCase = CPythonTestCase
|
__TestCase = CPythonTestCase
|
||||||
|
|
||||||
|
|
@ -690,7 +690,6 @@ class IntTestCases(__TestCase):
|
||||||
self.assertEqual(n, 1)
|
self.assertEqual(n, 1)
|
||||||
self.assertIs(type(n), IntSubclass)
|
self.assertIs(type(n), IntSubclass)
|
||||||
|
|
||||||
@skipIfTorchDynamo("flaky under dynamo")
|
|
||||||
def test_error_message(self):
|
def test_error_message(self):
|
||||||
def check(s, base=None):
|
def check(s, base=None):
|
||||||
with self.assertRaises(ValueError,
|
with self.assertRaises(ValueError,
|
||||||
|
|
|
||||||
|
|
@ -734,7 +734,6 @@ class TestVmapAPI(TestCase):
|
||||||
# warning, not a warning from the vmap fallback path.
|
# warning, not a warning from the vmap fallback path.
|
||||||
self.assertEqual(len(wa), 1)
|
self.assertEqual(len(wa), 1)
|
||||||
|
|
||||||
@skipIfTorchDynamo("Flaky test")
|
|
||||||
@unittest.expectedFailure
|
@unittest.expectedFailure
|
||||||
def test_fallback_warns_when_warnings_are_enabled(self):
|
def test_fallback_warns_when_warnings_are_enabled(self):
|
||||||
# NB: One day we will implement a batching rule for torch.atan2.
|
# NB: One day we will implement a batching rule for torch.atan2.
|
||||||
|
|
|
||||||
|
|
@ -374,7 +374,7 @@ def raise_observed_exception(
|
||||||
# stack and raise the exception.
|
# stack and raise the exception.
|
||||||
exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {}) # type: ignore[arg-type]
|
exception_vt = BuiltinVariable(exc_type).call_function(tx, args or [], kwargs or {}) # type: ignore[arg-type]
|
||||||
tx.exn_vt_stack.set_current_exception(exception_vt)
|
tx.exn_vt_stack.set_current_exception(exception_vt)
|
||||||
raise get_dynamo_observed_exception(exc_type)
|
raise observed_exception_map[exc_type]
|
||||||
|
|
||||||
|
|
||||||
def handle_observed_exception(tx: Any) -> None:
|
def handle_observed_exception(tx: Any) -> None:
|
||||||
|
|
|
||||||
|
|
@ -186,15 +186,6 @@ def set_difference_update(set1, *others):
|
||||||
set1.update(result)
|
set1.update(result)
|
||||||
|
|
||||||
|
|
||||||
def assert_multi_line_equal(self_, first, second, msg=None):
|
|
||||||
return self_.assertTrue(first == second, msg)
|
|
||||||
|
|
||||||
|
|
||||||
# The original impl. uses difflib
|
|
||||||
def assert_sequence_equal(self_, seq1, seq2, msg=None, seq_type=None):
|
|
||||||
return self_.assertTrue(seq1 == seq2, msg)
|
|
||||||
|
|
||||||
|
|
||||||
def getattr_and_trace(*args, **kwargs):
|
def getattr_and_trace(*args, **kwargs):
|
||||||
wrapper_obj = args[0]
|
wrapper_obj = args[0]
|
||||||
attr_name = args[1]
|
attr_name = args[1]
|
||||||
|
|
|
||||||
|
|
@ -23,8 +23,3 @@ def intern(string: str, /) -> str:
|
||||||
@substitute_in_graph(sys.getrecursionlimit, can_constant_fold_through=True)
|
@substitute_in_graph(sys.getrecursionlimit, can_constant_fold_through=True)
|
||||||
def getrecursionlimit() -> int:
|
def getrecursionlimit() -> int:
|
||||||
return sys.getrecursionlimit()
|
return sys.getrecursionlimit()
|
||||||
|
|
||||||
|
|
||||||
@substitute_in_graph(sys.get_int_max_str_digits, can_constant_fold_through=True)
|
|
||||||
def get_int_max_str_digits() -> int:
|
|
||||||
return sys.get_int_max_str_digits()
|
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,6 @@ from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.testing
|
import torch.testing
|
||||||
from torch._dynamo import polyfills
|
|
||||||
from torch._logging._internal import trace_log
|
from torch._logging._internal import trace_log
|
||||||
from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
|
from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
|
||||||
IS_WINDOWS,
|
IS_WINDOWS,
|
||||||
|
|
@ -137,8 +136,8 @@ class CPythonTestCase(TestCase):
|
||||||
assertRegex = unittest.TestCase.assertRegex
|
assertRegex = unittest.TestCase.assertRegex
|
||||||
assertNotRegex = unittest.TestCase.assertNotRegex
|
assertNotRegex = unittest.TestCase.assertNotRegex
|
||||||
assertCountEqual = unittest.TestCase.assertCountEqual
|
assertCountEqual = unittest.TestCase.assertCountEqual
|
||||||
assertMultiLineEqual = polyfills.assert_multi_line_equal
|
assertMultiLineEqual = unittest.TestCase.assertMultiLineEqual
|
||||||
assertSequenceEqual = polyfills.assert_sequence_equal
|
assertSequenceEqual = unittest.TestCase.assertSequenceEqual
|
||||||
assertListEqual = unittest.TestCase.assertListEqual
|
assertListEqual = unittest.TestCase.assertListEqual
|
||||||
assertTupleEqual = unittest.TestCase.assertTupleEqual
|
assertTupleEqual = unittest.TestCase.assertTupleEqual
|
||||||
assertSetEqual = unittest.TestCase.assertSetEqual
|
assertSetEqual = unittest.TestCase.assertSetEqual
|
||||||
|
|
|
||||||
|
|
@ -1277,12 +1277,6 @@ class BuiltinVariable(VariableTracker):
|
||||||
if isinstance(args[0], ConstantVariable):
|
if isinstance(args[0], ConstantVariable):
|
||||||
return args[0].call_method(tx, name, args[1:], kwargs)
|
return args[0].call_method(tx, name, args[1:], kwargs)
|
||||||
|
|
||||||
if self.fn is float and len(args) >= 1:
|
|
||||||
if isinstance(args[0], ConstantVariable):
|
|
||||||
return ConstantVariable.create(
|
|
||||||
getattr(float, name)(args[0].as_python_constant())
|
|
||||||
)
|
|
||||||
|
|
||||||
return super().call_method(tx, name, args, kwargs)
|
return super().call_method(tx, name, args, kwargs)
|
||||||
|
|
||||||
def _call_int_float(self, tx: "InstructionTranslator", arg):
|
def _call_int_float(self, tx: "InstructionTranslator", arg):
|
||||||
|
|
@ -2068,6 +2062,7 @@ class BuiltinVariable(VariableTracker):
|
||||||
"assertNotWarns",
|
"assertNotWarns",
|
||||||
"assertWarnsRegex",
|
"assertWarnsRegex",
|
||||||
"assertDictEqual",
|
"assertDictEqual",
|
||||||
|
"assertSequenceEqual",
|
||||||
"assertWarns",
|
"assertWarns",
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -173,14 +173,7 @@ its type to `common_constant_types`.
|
||||||
raise_observed_exception(type(e), tx)
|
raise_observed_exception(type(e), tx)
|
||||||
elif isinstance(self.value, (float, int)):
|
elif isinstance(self.value, (float, int)):
|
||||||
if not (args or kwargs):
|
if not (args or kwargs):
|
||||||
try:
|
|
||||||
return ConstantVariable.create(getattr(self.value, name)())
|
return ConstantVariable.create(getattr(self.value, name)())
|
||||||
except (OverflowError, ValueError) as exc:
|
|
||||||
raise_observed_exception(
|
|
||||||
type(exc),
|
|
||||||
tx,
|
|
||||||
args=list(map(ConstantVariable.create, exc.args)),
|
|
||||||
)
|
|
||||||
if (
|
if (
|
||||||
hasattr(operator, name)
|
hasattr(operator, name)
|
||||||
and len(args) == 1
|
and len(args) == 1
|
||||||
|
|
@ -210,14 +203,9 @@ its type to `common_constant_types`.
|
||||||
if name == "__len__" and not (args or kwargs):
|
if name == "__len__" and not (args or kwargs):
|
||||||
return ConstantVariable.create(len(self.value))
|
return ConstantVariable.create(len(self.value))
|
||||||
elif name == "__round__" and len(args) == 1 and args[0].is_python_constant():
|
elif name == "__round__" and len(args) == 1 and args[0].is_python_constant():
|
||||||
try:
|
|
||||||
return ConstantVariable.create(
|
return ConstantVariable.create(
|
||||||
round(self.value, args[0].as_python_constant())
|
round(self.value, args[0].as_python_constant())
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
raise_observed_exception(
|
|
||||||
type(e), tx, args=list(map(ConstantVariable.create, e.args))
|
|
||||||
)
|
|
||||||
elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant():
|
elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant():
|
||||||
assert not kwargs
|
assert not kwargs
|
||||||
search = args[0].as_python_constant()
|
search = args[0].as_python_constant()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user