Fix wrongly exposed variables in torch/__init__.py (#127795)

Screenshot: https://github.com/pytorch/pytorch/assets/16078332/964c6707-1856-4c2c-8cd8-ce1d96d38d36

This PR removes temporary variables that `torch/__init__.py` accidentally exposed as public attributes of the `torch` module: module-level loop temporaries such as `name`, `obj`, and `candidate` survived import, and `obj`/`candidate` even had to be carried in the public-API allowlist and in Dynamo's trace rules. The fix renames the temporaries to dunder-prefixed names and explicitly `del`s them once module initialization is done.
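For context, a minimal sketch of the failure mode, using a hypothetical module `mymod` (not PyTorch code): a `for` loop at module scope binds its loop variable in the module namespace, where it survives as a public-looking attribute unless explicitly deleted.

```python
# mymod.py -- hypothetical illustration, not PyTorch code
for name in ("sqrt", "cos"):
    globals()[f"_sym_{name}"] = None  # stand-in for the real binding work

# >>> import mymod
# >>> mymod.name   # the leaked loop temporary is now module "API"
# 'cos'

# The pattern this PR adopts: dunder-prefixed temporaries, deleted after use.
__name = ""
for __name in ("sqrt", "cos"):
    globals()[f"_sym_{__name}"] = None
del __name  # nothing stray is left behind
```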

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127795
Approved by: https://github.com/albanD
Authored by Xuehai Pan on 2024-06-05 05:53:23 +00:00; committed by PyTorch MergeBot
parent 457df212e1
commit c97e3ebb96
3 changed files with 48 additions and 44 deletions

test/allowlist_for_publicAPI.json

@@ -1321,12 +1321,10 @@
     "_weight_norm_interface",
     "autocast",
     "broadcast_shapes",
-    "candidate",
     "compiled_with_cxx11_abi",
     "from_dlpack",
     "lobpcg",
     "lu",
-    "obj",
     "segment_reduce",
     "set_default_dtype",
     "set_grad_enabled",

torch/__init__.py

@@ -699,8 +699,6 @@ def sym_min(a, b):
     return builtins.min(a, b)
 
 # Drop in replacement for math.sqrt, math.sin, math.cos etc
-current_module = sys.modules[__name__]
-
 def _get_sym_math_fn(name):
     def fn(a):
         from .overrides import has_torch_function_unary, handle_torch_function
@@ -713,18 +711,19 @@ def _get_sym_math_fn(name):
     return fn
 
 
-for name in ("sqrt", "cos", "cosh", "sin", "sinh", "tan", "tanh", "asin", "acos", "atan"):
-    sym_name = f"_sym_{name}"
-    fn = _get_sym_math_fn(name)
-    fn.__qualname__ = fn.__name__ = sym_name
-    setattr(current_module, sym_name, fn)
+__fn, __name, __sym_name = None, '', ''
+for __name in ("sqrt", "cos", "cosh", "sin", "sinh", "tan", "tanh", "asin", "acos", "atan"):
+    __sym_name = f"_sym_{__name}"
+    __fn = _get_sym_math_fn(__name)
+    __fn.__qualname__ = __fn.__name__ = __sym_name
+    globals()[__sym_name] = __fn
+del __fn, __name, __sym_name, _get_sym_math_fn
 
 # Adding temporary shortcut
-sym_sqrt = current_module._sym_sqrt
+sym_sqrt = globals()["_sym_sqrt"]
 __all__.append("sym_sqrt")
-del fn, name, sym_name, current_module  # type: ignore[possibly-undefined]
 
 
 def sym_ite(b, t, f):
     from .overrides import has_torch_function, handle_torch_function
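At module scope, `globals()` is the module's own `__dict__`, so `globals()[__sym_name] = __fn` binds the wrapper exactly as `setattr(current_module, __sym_name, __fn)` did, without keeping a `current_module` alias alive. A simplified sketch of the factory pattern (a stand-in for `_get_sym_math_fn`; the real wrappers also dispatch on SymInt/SymFloat):

```python
import math

def _get_fn(name):
    math_fn = getattr(math, name)

    def fn(a):
        # Real version: dispatch via torch.overrides / SymFloat first,
        # then fall back to math. This sketch keeps only the fallback.
        return math_fn(a)

    return fn

__fn, __name = None, ""
for __name in ("sqrt", "cos"):
    __fn = _get_fn(__name)
    __fn.__qualname__ = __fn.__name__ = f"_sym_{__name}"
    globals()[f"_sym_{__name}"] = __fn
del __fn, __name, _get_fn

print(_sym_sqrt(4))  # 2.0
```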
@@ -760,30 +759,35 @@ except ImportError:
     ''').strip()) from None
     raise  # If __file__ is not None the cause is unknown, so just re-raise.
 
-for name in dir(_C):
-    if name[0] != '_' and not name.endswith('Base'):
-        __all__.append(name)
-        obj = getattr(_C, name)
-        if (isinstance(obj, Callable) or inspect.isclass(obj)):  # type: ignore[arg-type]
-            if (obj.__module__ != 'torch'):
+
+__name, __obj = '', None
+for __name in dir(_C):
+    if __name[0] != '_' and not __name.endswith('Base'):
+        __all__.append(__name)
+        __obj = getattr(_C, __name)
+        if callable(__obj) or inspect.isclass(__obj):
+            if __obj.__module__ != __name__:
                 # TODO: fix their module from C++ side
-                if name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']:
-                    obj.__module__ = 'torch'
-    elif name == 'TensorBase':
+                if __name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']:
+                    __obj.__module__ = __name__
+    elif __name == 'TensorBase':
         # issue 109438 / pr 109940. Prevent TensorBase from being copied into torch.
-        delattr(sys.modules[__name__], name)
+        delattr(sys.modules[__name__], __name)
+del __name, __obj
+
 
 if not TYPE_CHECKING:
     # issue 38137 and python issue 43367. Submodules of a C extension are
     # non-standard, and attributes of those submodules cannot be pickled since
     # pickle expect to be able to import them as "from _C.sub import attr"
     # which fails with "_C is not a package
-    for attr in dir(_C):
-        candidate = getattr(_C, attr)
-        if type(candidate) is type(_C):
+    __name, __candidate = '', None
+    for __name in dir(_C):
+        __candidate = getattr(_C, __name)
+        if type(__candidate) is type(_C):
             # submodule
-            if f'torch._C.{attr}' not in sys.modules:
-                sys.modules[f'torch._C.{attr}'] = candidate
+            sys.modules.setdefault(f"{__name__}._C.{__name}", __candidate)
+    del __name, __candidate
 
 
 ################################################################################
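Besides the rename, the submodule loop swaps the explicit `not in sys.modules` check for `sys.modules.setdefault`, which registers an entry only when none exists. A hedged sketch of why the registration is needed at all (hypothetical module names; the underlying issue is that pickle resolves attributes via `from torch._C.sub import attr`, which only works if `torch._C.sub` is importable):

```python
import sys
import types

# Hypothetical C-extension layout: _C is a module (not a package) that
# carries a submodule-like attribute "sub".
_C = types.ModuleType("fake_C")
_C.sub = types.ModuleType("fake_C.sub")

# "from fake_C.sub import attr" consults sys.modules; seeding the entries
# makes the submodule importable even though fake_C is not a package.
sys.modules.setdefault("fake_C", _C)
sys.modules.setdefault("fake_C.sub", _C.sub)  # no-op if already present

import fake_C.sub  # succeeds only because of the sys.modules entries
```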
@@ -1669,7 +1673,7 @@ from ._tensor_str import set_printoptions
 # Initialize extension
 ################################################################################
 
-def manager_path():
+def _manager_path():
     if _running_with_deploy() or platform.system() == 'Windows':
         return b""
     path = get_file_path('torch', 'bin', 'torch_shm_manager')
@@ -1686,8 +1690,8 @@ py_float = float
 py_int = int
 
 # Shared memory manager needs to know the exact location of manager executable
-_C._initExtension(manager_path())
-del manager_path
+_C._initExtension(_manager_path())
+del _manager_path
 
 # Appease the type checker: it can't deal with direct setting of globals().
 # Note that we will see "too many" functions when reexporting this way; there
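The helper rename follows the same hygiene: `manager_path` runs exactly once during import and is deleted immediately after, but until this PR it sat in the namespace under a public-looking name. With the underscore prefix it is conventionally private even for its brief lifetime. The shape of the pattern:

```python
def _manager_path():
    # One-shot helper: resolve the shm-manager binary path at import time.
    return b"/path/to/torch_shm_manager"  # hypothetical value

_path = _manager_path()  # the value is consumed here...
del _manager_path        # ...so deleting the *name* right after is safe
```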
@@ -1708,20 +1712,22 @@ PRIVATE_OPS = (
     'unique_dim',
 )
 
-for name in dir(_C._VariableFunctions):
-    if name.startswith('__') or name in PRIVATE_OPS:
+__name, __obj = '', None
+for __name in dir(_C._VariableFunctions):
+    if __name.startswith('__') or __name in PRIVATE_OPS:
         continue
-    obj = getattr(_C._VariableFunctions, name)
-    obj.__module__ = 'torch'
+    __obj = getattr(_C._VariableFunctions, __name)
+    __obj.__module__ = __name__
     # Hide some APIs that should not be public
-    if name == "segment_reduce":
+    if __name == "segment_reduce":
         # TODO: Once the undocumented FC window is passed, remove the line bellow
-        globals()[name] = obj
-        name = "_" + name
-    globals()[name] = obj
-    if not name.startswith("_"):
-        __all__.append(name)
+        globals()[__name] = __obj
+        __name = "_" + __name
+    globals()[__name] = __obj
+    if not __name.startswith("_"):
+        __all__.append(__name)
+del __name, __obj
 
 
 ################################################################################
 # Add torch.dtype instances to the public API
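The `segment_reduce` branch keeps the op reachable under both names while the forward-compatibility (FC) window mentioned in the TODO is open: the legacy public name stays bound, the underscored name becomes the blessed one, and only names without a leading underscore are appended to `__all__`. A hedged spot check of the resulting state (assumes a build where the window is still open):

```python
import torch

torch._segment_reduce  # the blessed, underscore-prefixed binding
torch.segment_reduce   # legacy binding, kept only during the FC window
assert "segment_reduce" not in torch.__all__  # hidden from the public API
```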
@@ -1729,9 +1735,9 @@ for name in dir(_C._VariableFunctions):
 import torch
 
-for attribute in dir(torch):
-    if isinstance(getattr(torch, attribute), torch.dtype):
-        __all__.append(attribute)
+__all__.extend(
+    name for name in dir(torch) if isinstance(getattr(torch, name), torch.dtype)
+)
 
 
 ################################################################################
 # Import TorchDynamo's lazy APIs to avoid circular dependenices
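Replacing the dtype loop with `__all__.extend(...)` over a generator expression sidesteps the leak entirely: in Python 3, comprehensions and generator expressions run in their own scope, so `name` never lands in the module namespace and no trailing `del` is needed. A quick illustration:

```python
for x in range(3):
    pass
print("x" in globals())  # True: for-loop targets leak at module scope

squares = [y * y for y in range(3)]
print("y" in globals())  # False: comprehension targets have their own scope
```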

torch/_dynamo/trace_rules.py

@@ -1995,7 +1995,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
         "torch.not_equal",
         "torch.nuclear_norm",
         "torch.numel",
-        "torch.obj",
         "torch.ones_like",
         "torch.ones",
         "torch.orgqr",
@@ -2182,6 +2181,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
         "torch.xlogy",
         "torch.zero_",
         "torch.zeros",
+        "torch.zeros_like",
         "torch._fused_sgd_",
         "torch.slice_inverse",
         "torch._assert_scalar",