mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix wrongly exposed variables in torch/__init__.py (#127795)
<img width="609" alt="image" src="https://github.com/pytorch/pytorch/assets/16078332/964c6707-1856-4c2c-8cd8-ce1d96d38d36"> This PR removes temporary variables in `torch/__init__.py`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127795 Approved by: https://github.com/albanD
This commit is contained in:
parent
457df212e1
commit
c97e3ebb96
|
|
@ -1321,12 +1321,10 @@
|
|||
"_weight_norm_interface",
|
||||
"autocast",
|
||||
"broadcast_shapes",
|
||||
"candidate",
|
||||
"compiled_with_cxx11_abi",
|
||||
"from_dlpack",
|
||||
"lobpcg",
|
||||
"lu",
|
||||
"obj",
|
||||
"segment_reduce",
|
||||
"set_default_dtype",
|
||||
"set_grad_enabled",
|
||||
|
|
|
|||
|
|
@ -699,8 +699,6 @@ def sym_min(a, b):
|
|||
return builtins.min(a, b)
|
||||
|
||||
# Drop in replacement for math.sqrt, math.sin, math.cos etc
|
||||
current_module = sys.modules[__name__]
|
||||
|
||||
def _get_sym_math_fn(name):
|
||||
def fn(a):
|
||||
from .overrides import has_torch_function_unary, handle_torch_function
|
||||
|
|
@ -713,18 +711,19 @@ def _get_sym_math_fn(name):
|
|||
|
||||
return fn
|
||||
|
||||
for name in ("sqrt", "cos", "cosh", "sin", "sinh", "tan", "tanh", "asin", "acos", "atan"):
|
||||
sym_name = f"_sym_{name}"
|
||||
fn = _get_sym_math_fn(name)
|
||||
fn.__qualname__ = fn.__name__ = sym_name
|
||||
setattr(current_module, sym_name, fn)
|
||||
__fn, __name, __sym_name = None, '', ''
|
||||
for __name in ("sqrt", "cos", "cosh", "sin", "sinh", "tan", "tanh", "asin", "acos", "atan"):
|
||||
__sym_name = f"_sym_{__name}"
|
||||
__fn = _get_sym_math_fn(__name)
|
||||
__fn.__qualname__ = __fn.__name__ = __sym_name
|
||||
globals()[__sym_name] = __fn
|
||||
|
||||
del __fn, __name, __sym_name, _get_sym_math_fn
|
||||
|
||||
# Adding temporary shortcut
|
||||
sym_sqrt = current_module._sym_sqrt
|
||||
sym_sqrt = globals()["_sym_sqrt"]
|
||||
__all__.append("sym_sqrt")
|
||||
|
||||
del fn, name, sym_name, current_module # type: ignore[possibly-undefined]
|
||||
|
||||
|
||||
def sym_ite(b, t, f):
|
||||
from .overrides import has_torch_function, handle_torch_function
|
||||
|
|
@ -760,30 +759,35 @@ except ImportError:
|
|||
''').strip()) from None
|
||||
raise # If __file__ is not None the cause is unknown, so just re-raise.
|
||||
|
||||
for name in dir(_C):
|
||||
if name[0] != '_' and not name.endswith('Base'):
|
||||
__all__.append(name)
|
||||
obj = getattr(_C, name)
|
||||
if (isinstance(obj, Callable) or inspect.isclass(obj)): # type: ignore[arg-type]
|
||||
if (obj.__module__ != 'torch'):
|
||||
__name, __obj = '', None
|
||||
for __name in dir(_C):
|
||||
if __name[0] != '_' and not __name.endswith('Base'):
|
||||
__all__.append(__name)
|
||||
__obj = getattr(_C, __name)
|
||||
if callable(__obj) or inspect.isclass(__obj):
|
||||
if __obj.__module__ != __name__:
|
||||
# TODO: fix their module from C++ side
|
||||
if name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']:
|
||||
obj.__module__ = 'torch'
|
||||
elif name == 'TensorBase':
|
||||
if __name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']:
|
||||
__obj.__module__ = __name__
|
||||
elif __name == 'TensorBase':
|
||||
# issue 109438 / pr 109940. Prevent TensorBase from being copied into torch.
|
||||
delattr(sys.modules[__name__], name)
|
||||
delattr(sys.modules[__name__], __name)
|
||||
|
||||
del __name, __obj
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
# issue 38137 and python issue 43367. Submodules of a C extension are
|
||||
# non-standard, and attributes of those submodules cannot be pickled since
|
||||
# pickle expect to be able to import them as "from _C.sub import attr"
|
||||
# which fails with "_C is not a package
|
||||
for attr in dir(_C):
|
||||
candidate = getattr(_C, attr)
|
||||
if type(candidate) is type(_C):
|
||||
__name, __candidate = '', None
|
||||
for __name in dir(_C):
|
||||
__candidate = getattr(_C, __name)
|
||||
if type(__candidate) is type(_C):
|
||||
# submodule
|
||||
if f'torch._C.{attr}' not in sys.modules:
|
||||
sys.modules[f'torch._C.{attr}'] = candidate
|
||||
sys.modules.setdefault(f"{__name__}._C.{__name}", __candidate)
|
||||
|
||||
del __name, __candidate
|
||||
|
||||
|
||||
################################################################################
|
||||
|
|
@ -1669,7 +1673,7 @@ from ._tensor_str import set_printoptions
|
|||
# Initialize extension
|
||||
################################################################################
|
||||
|
||||
def manager_path():
|
||||
def _manager_path():
|
||||
if _running_with_deploy() or platform.system() == 'Windows':
|
||||
return b""
|
||||
path = get_file_path('torch', 'bin', 'torch_shm_manager')
|
||||
|
|
@ -1686,8 +1690,8 @@ py_float = float
|
|||
py_int = int
|
||||
|
||||
# Shared memory manager needs to know the exact location of manager executable
|
||||
_C._initExtension(manager_path())
|
||||
del manager_path
|
||||
_C._initExtension(_manager_path())
|
||||
del _manager_path
|
||||
|
||||
# Appease the type checker: it can't deal with direct setting of globals().
|
||||
# Note that we will see "too many" functions when reexporting this way; there
|
||||
|
|
@ -1708,20 +1712,22 @@ PRIVATE_OPS = (
|
|||
'unique_dim',
|
||||
)
|
||||
|
||||
for name in dir(_C._VariableFunctions):
|
||||
if name.startswith('__') or name in PRIVATE_OPS:
|
||||
__name, __obj = '', None
|
||||
for __name in dir(_C._VariableFunctions):
|
||||
if __name.startswith('__') or __name in PRIVATE_OPS:
|
||||
continue
|
||||
obj = getattr(_C._VariableFunctions, name)
|
||||
obj.__module__ = 'torch'
|
||||
__obj = getattr(_C._VariableFunctions, __name)
|
||||
__obj.__module__ = __name__
|
||||
# Hide some APIs that should not be public
|
||||
if name == "segment_reduce":
|
||||
if __name == "segment_reduce":
|
||||
# TODO: Once the undocumented FC window is passed, remove the line bellow
|
||||
globals()[name] = obj
|
||||
name = "_" + name
|
||||
globals()[name] = obj
|
||||
if not name.startswith("_"):
|
||||
__all__.append(name)
|
||||
globals()[__name] = __obj
|
||||
__name = "_" + __name
|
||||
globals()[__name] = __obj
|
||||
if not __name.startswith("_"):
|
||||
__all__.append(__name)
|
||||
|
||||
del __name, __obj
|
||||
|
||||
################################################################################
|
||||
# Add torch.dtype instances to the public API
|
||||
|
|
@ -1729,9 +1735,9 @@ for name in dir(_C._VariableFunctions):
|
|||
|
||||
import torch
|
||||
|
||||
for attribute in dir(torch):
|
||||
if isinstance(getattr(torch, attribute), torch.dtype):
|
||||
__all__.append(attribute)
|
||||
__all__.extend(
|
||||
name for name in dir(torch) if isinstance(getattr(torch, name), torch.dtype)
|
||||
)
|
||||
|
||||
################################################################################
|
||||
# Import TorchDynamo's lazy APIs to avoid circular dependenices
|
||||
|
|
|
|||
|
|
@ -1995,7 +1995,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
|
|||
"torch.not_equal",
|
||||
"torch.nuclear_norm",
|
||||
"torch.numel",
|
||||
"torch.obj",
|
||||
"torch.ones_like",
|
||||
"torch.ones",
|
||||
"torch.orgqr",
|
||||
|
|
@ -2182,6 +2181,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
|
|||
"torch.xlogy",
|
||||
"torch.zero_",
|
||||
"torch.zeros",
|
||||
"torch.zeros_like",
|
||||
"torch._fused_sgd_",
|
||||
"torch.slice_inverse",
|
||||
"torch._assert_scalar",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user