From 75dae4f691e4481a68079ac55e1c44dc3211d16a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 4 Jan 2024 14:00:11 +0000 Subject: [PATCH] Revert "[dynamo] Fix np.issubdtype (#116459)" This reverts commit b5c33ccdb3198a48a354e21a4fdace0ec6d04146. Reverted https://github.com/pytorch/pytorch/pull/116459 on behalf of https://github.com/zou3519 due to Broke CI, seems to be a landrace ([comment](https://github.com/pytorch/pytorch/pull/116459#issuecomment-1877135999)) --- test/dynamo/test_misc.py | 11 ------- torch/_dynamo/variables/misc.py | 41 ++++++++----------------- torch/_numpy/_dtypes.py | 54 ++++++++++++++------------------- 3 files changed, 35 insertions(+), 71 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 6405f545e00..5748ed561fc 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1875,17 +1875,6 @@ utils_device.CURRENT_DEVICE == None""".split( self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 2) - def test_numpy_subdtype(self): - def fn(x, n): - return np.issubdtype(type(n), np.integer) + x - - args = [torch.randn(10), 4096] - correct = fn(*args) - cnts = torch._dynamo.testing.CompileCounter() - opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) - self.assertEqual(opt_fn(*args), correct) - self.assertEqual(cnts.frame_count, 1) - def test_numpy_take_along_axis(self): def fn(x, i, a): return np.take_along_axis(x, i, a) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 75ad8f06d95..f541e20aa60 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -16,7 +16,6 @@ from ..guards import GuardBuilder, install_guard from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource from ..utils import ( check_constant_args, - check_unspec_python_args, identity, is_tensor_base_attr_getter, proxy_args_kwargs, @@ -834,18 +833,10 @@ class NumpyVariable(VariableTracker): Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes. """ - constant_fold_functions = (tnp.issubdtype,) - def __init__(self, value, **kwargs): super().__init__(**kwargs) self.value = value - @classmethod - def can_constant_fold_through(cls, fn): - mod = fn.__module__.split(".") - assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"] - return fn in cls.constant_fold_functions - def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": @@ -877,21 +868,8 @@ class NumpyVariable(VariableTracker): msg += f"confg.use_numpy_random_stream={config.use_numpy_random_stream}" unimplemented(msg) - constant_args = check_constant_args(args, kwargs) - unspec_python_args = check_unspec_python_args(args, kwargs) - - if self.can_constant_fold_through(func) and ( - constant_args or unspec_python_args - ): - # constant fold - return variables.ConstantVariable.create( - self.as_python_constant()( - *[x.as_python_constant() for x in args], - **{k: v.as_python_constant() for k, v in kwargs.items()}, - ), - ) - - # TODO Add all the functions that go from constants to constants to can_constant_fold_through + # TODO(larryliu0820): currently assuming all numpy.* functions are returning a ndarray that can be + # wrapped by NumpyNdarrayVariable which is wrong! proxy = tx.output.create_proxy( "call_function", numpy_to_tensor_wrapper(func), @@ -915,11 +893,18 @@ class NumpyVariable(VariableTracker): return self.value def as_proxy(self): + # this handles numpy dtype attribute such as np.float32. TODO(larryliu0820): we should split NumpyVariable + # into NumpyVariable for instances/objects and NumpyVariable for types. if config.trace_numpy and isinstance(self.value, type): - # This handles numpy dtype attributes such as np.float32 - # We return a string as we don't want to serialize non-PyTorch objects in the output FX graph - # In torch/_numpy we normalize strings to their dtypes when the input is a dtype, as NumPy does - return self.value.__name__ + # retrieve attribute str. E.g., "float32" if given np.float32 + + attr = self.value.__name__ + # get tnp equivalent + tnp_dtype = tnp.dtype(attr) + # returning a string here because we are assuming all `dtype` kwargs for numpy + # functions can take an equivalent string and the behavior of the function would + # be the same as taking a numpy dtype. + return tnp_dtype.name return super().as_proxy() diff --git a/torch/_numpy/_dtypes.py b/torch/_numpy/_dtypes.py index b13064d6321..e002db2c0a4 100644 --- a/torch/_numpy/_dtypes.py +++ b/torch/_numpy/_dtypes.py @@ -12,7 +12,9 @@ from . import _dtypes_impl class generic: - name = "generic" + @property + def name(self): + return self.__class__.__name__ def __new__(cls, value): # NumPy scalars are modelled as 0-D arrays @@ -35,44 +37,33 @@ class generic: class number(generic): - name = "number" + pass class integer(number): - name = "integer" + pass class inexact(number): - name = "inexact" + pass class signedinteger(integer): - name = "signedinteger" + pass class unsignedinteger(integer): - name = "unsignedinteger" + pass class floating(inexact): - name = "floating" + pass class complexfloating(inexact): - name = "complexfloating" + pass -_abstract_dtypes = [ - "generic", - "number", - "integer", - "signedinteger", - "unsignedinteger", - "inexact", - "floating", - "complexfloating", -] - # ##### concrete types # signed integers @@ -408,17 +399,6 @@ def issubclass_(arg, klass): def issubdtype(arg1, arg2): # cf https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numerictypes.py#L356-L420 - - # We also accept strings even if NumPy doesn't as dtypes are serialized as their - # string representation in dynamo's graph - def str_to_abstract(t): - if isinstance(t, str) and t in _abstract_dtypes: - return globals()[t] - return t - - arg1 = str_to_abstract(arg1) - arg2 = str_to_abstract(arg2) - if not issubclass_(arg1, generic): arg1 = dtype(arg1).type if not issubclass_(arg2, generic): @@ -426,7 +406,17 @@ def issubdtype(arg1, arg2): return issubclass(arg1, arg2) -__all__ = ["dtype", "DType", "typecodes", "issubdtype", "set_default_dtype", "sctypes"] +__all__ = ["dtype", "DType", "typecodes", "issubdtype", "set_default_dtype"] __all__ += list(_names.keys()) # noqa: PLE0605 __all__ += list(_name_aliases.keys()) # noqa: PLE0605 -__all__ += _abstract_dtypes # noqa: PLE0605 +__all__ += [ # noqa: PLE0605 + "sctypes", + "generic", + "number", + "integer", + "signedinteger", + "unsignedinteger", + "inexact", + "floating", + "complexfloating", +]