[BE][Easy] export explicitly imported public submodules (#127703)

Add top-level submodules `torch.{storage,serialization,functional,amp,overrides,types}`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127703
Approved by: https://github.com/ezyang
This commit is contained in:
Xuehai Pan 2024-06-10 19:16:54 +00:00 committed by PyTorch MergeBot
parent 62311257ad
commit dcc0093dba
3 changed files with 57 additions and 64 deletions

View File

@ -1,6 +1,4 @@
# mypy: allow-untyped-defs """
r"""
The torch package contains data structures for multi-dimensional The torch package contains data structures for multi-dimensional
tensors and defines mathematical operations over these tensors. tensors and defines mathematical operations over these tensors.
Additionally, it provides many utilities for efficient serialization of Additionally, it provides many utilities for efficient serialization of
@ -10,6 +8,8 @@ It has a CUDA counterpart, that enables you to run your tensor computations
on an NVIDIA GPU with compute capability >= 3.0. on an NVIDIA GPU with compute capability >= 3.0.
""" """
# mypy: allow-untyped-defs
import math import math
import os import os
import sys import sys
@ -289,10 +289,6 @@ else:
_load_global_deps() _load_global_deps()
from torch._C import * # noqa: F403 from torch._C import * # noqa: F403
# Appease the type checker; ordinarily this binding is inserted by the
# torch._C module initialization code in C
if TYPE_CHECKING:
from . import _C as _C # noqa: TCH004
class SymInt: class SymInt:
""" """
@ -614,10 +610,9 @@ def sym_not(a):
a (SymBool or bool): Object to negate a (SymBool or bool): Object to negate
""" """
import sympy import sympy
from .overrides import has_torch_function_unary, handle_torch_function
if has_torch_function_unary(a): if overrides.has_torch_function_unary(a):
return handle_torch_function(sym_not, (a,), a) return overrides.handle_torch_function(sym_not, (a,), a)
if hasattr(a, '__sym_not__'): if hasattr(a, '__sym_not__'):
return a.__sym_not__() return a.__sym_not__()
if isinstance(a, sympy.Basic): if isinstance(a, sympy.Basic):
@ -630,10 +625,8 @@ def sym_float(a):
Args: Args:
a (SymInt, SymFloat, or object): Object to cast a (SymInt, SymFloat, or object): Object to cast
""" """
from .overrides import has_torch_function_unary, handle_torch_function if overrides.has_torch_function_unary(a):
return overrides.handle_torch_function(sym_float, (a,), a)
if has_torch_function_unary(a):
return handle_torch_function(sym_float, (a,), a)
if isinstance(a, SymFloat): if isinstance(a, SymFloat):
return a return a
elif hasattr(a, '__sym_float__'): elif hasattr(a, '__sym_float__'):
@ -647,10 +640,8 @@ def sym_int(a):
Args: Args:
a (SymInt, SymFloat, or object): Object to cast a (SymInt, SymFloat, or object): Object to cast
""" """
from .overrides import has_torch_function_unary, handle_torch_function if overrides.has_torch_function_unary(a):
return overrides.handle_torch_function(sym_int, (a,), a)
if has_torch_function_unary(a):
return handle_torch_function(sym_int, (a,), a)
if isinstance(a, SymInt): if isinstance(a, SymInt):
return a return a
elif isinstance(a, SymFloat): elif isinstance(a, SymFloat):
@ -664,10 +655,8 @@ def sym_max(a, b):
promotes to float if any argument is float (unlike builtins.max, which promotes to float if any argument is float (unlike builtins.max, which
will faithfully preserve the type of the input argument). will faithfully preserve the type of the input argument).
""" """
from .overrides import has_torch_function, handle_torch_function if overrides.has_torch_function((a, b)):
return overrides.handle_torch_function(sym_max, (a, b), a, b)
if has_torch_function((a, b)):
return handle_torch_function(sym_max, (a, b), a, b)
if isinstance(a, (SymInt, SymFloat)): if isinstance(a, (SymInt, SymFloat)):
return a.__sym_max__(b) return a.__sym_max__(b)
elif isinstance(b, (SymInt, SymFloat)): elif isinstance(b, (SymInt, SymFloat)):
@ -683,11 +672,9 @@ def sym_max(a, b):
return builtins.max(a, b) return builtins.max(a, b)
def sym_min(a, b): def sym_min(a, b):
""" SymInt-aware utility for min().""" """SymInt-aware utility for min()."""
from .overrides import has_torch_function, handle_torch_function if overrides.has_torch_function((a, b)):
return overrides.handle_torch_function(sym_min, (a, b), a, b)
if has_torch_function((a, b)):
return handle_torch_function(sym_min, (a, b), a, b)
if isinstance(a, (SymInt, SymFloat)): if isinstance(a, (SymInt, SymFloat)):
return a.__sym_min__(b) return a.__sym_min__(b)
elif isinstance(b, (SymInt, SymFloat)): elif isinstance(b, (SymInt, SymFloat)):
@ -702,10 +689,8 @@ def sym_min(a, b):
# Drop in replacement for math.sqrt, math.sin, math.cos etc # Drop in replacement for math.sqrt, math.sin, math.cos etc
def _get_sym_math_fn(name): def _get_sym_math_fn(name):
def fn(a): def fn(a):
from .overrides import has_torch_function_unary, handle_torch_function if overrides.has_torch_function_unary(a):
return overrides.handle_torch_function(fn, (a,), a)
if has_torch_function_unary(a):
return handle_torch_function(fn, (a,), a)
if hasattr(a, f"__sym_{name}__"): if hasattr(a, f"__sym_{name}__"):
return getattr(a, f"__sym_{name}__")() return getattr(a, f"__sym_{name}__")()
return getattr(math, name)(a) return getattr(math, name)(a)
@ -727,10 +712,8 @@ __all__.append("sym_sqrt")
def sym_ite(b, t, f): def sym_ite(b, t, f):
from .overrides import has_torch_function, handle_torch_function if overrides.has_torch_function((b, t, f)):
return overrides.handle_torch_function(sym_ite, (b, t, f), b, t, f)
if has_torch_function((b, t, f)):
return handle_torch_function(sym_ite, (b, t, f), b, t, f)
assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f) assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f)
if isinstance(b, SymBool): if isinstance(b, SymBool):
return b.__sym_ite__(t, f) return b.__sym_ite__(t, f)
@ -760,16 +743,20 @@ except ImportError:
''').strip()) from None ''').strip()) from None
raise # If __file__ is not None the cause is unknown, so just re-raise. raise # If __file__ is not None the cause is unknown, so just re-raise.
# The torch._C submodule is already loaded via `from torch._C import *` above
# Make an explicit reference to the _C submodule to appease linters
from torch import _C as _C
__name, __obj = '', None __name, __obj = '', None
for __name in dir(_C): for __name in dir(_C):
if __name[0] != '_' and not __name.endswith('Base'): if __name[0] != '_' and not __name.endswith('Base'):
__all__.append(__name) __all__.append(__name)
__obj = getattr(_C, __name) __obj = getattr(_C, __name)
if callable(__obj) or inspect.isclass(__obj): if callable(__obj) or inspect.isclass(__obj):
if __obj.__module__ != __name__: if __obj.__module__ != __name__: # "torch"
# TODO: fix their module from C++ side # TODO: fix their module from C++ side
if __name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']: if __name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']:
__obj.__module__ = __name__ __obj.__module__ = __name__ # "torch"
elif __name == 'TensorBase': elif __name == 'TensorBase':
# issue 109438 / pr 109940. Prevent TensorBase from being copied into torch. # issue 109438 / pr 109940. Prevent TensorBase from being copied into torch.
delattr(sys.modules[__name__], __name) delattr(sys.modules[__name__], __name)
@ -1478,6 +1465,7 @@ __all__.extend(['e', 'pi', 'nan', 'inf', 'newaxis'])
################################################################################ ################################################################################
from ._tensor import Tensor from ._tensor import Tensor
from torch import storage as storage
from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage, _warn_typed_storage_removal from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage, _warn_typed_storage_removal
# NOTE: New <type>Storage classes should never be added. When adding a new # NOTE: New <type>Storage classes should never be added. When adding a new
@ -1665,7 +1653,9 @@ _storage_classes = {
_tensor_classes: Set[Type] = set() _tensor_classes: Set[Type] = set()
# If you edit these imports, please update torch/__init__.py.in as well # If you edit these imports, please update torch/__init__.py.in as well
from torch import random as random
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
from torch import serialization as serialization
from .serialization import save, load from .serialization import save, load
from ._tensor_str import set_printoptions from ._tensor_str import set_printoptions
@ -1682,6 +1672,7 @@ def _manager_path():
raise RuntimeError("Unable to find torch_shm_manager at " + path) raise RuntimeError("Unable to find torch_shm_manager at " + path)
return path.encode('utf-8') return path.encode('utf-8')
from torch import amp as amp
from torch.amp import autocast, GradScaler from torch.amp import autocast, GradScaler
# Initializing the extension shadows the built-in python float / int classes; # Initializing the extension shadows the built-in python float / int classes;
@ -1717,7 +1708,7 @@ for __name in dir(_C._VariableFunctions):
if __name.startswith('__') or __name in PRIVATE_OPS: if __name.startswith('__') or __name in PRIVATE_OPS:
continue continue
__obj = getattr(_C._VariableFunctions, __name) __obj = getattr(_C._VariableFunctions, __name)
__obj.__module__ = __name__ __obj.__module__ = __name__ # "torch"
# Hide some APIs that should not be public # Hide some APIs that should not be public
if __name == "segment_reduce": if __name == "segment_reduce":
# TODO: Once the undocumented FC window is passed, remove the line bellow # TODO: Once the undocumented FC window is passed, remove the line bellow
@ -1751,6 +1742,7 @@ from ._compile import _disable_dynamo
################################################################################ ################################################################################
# needs to be after the above ATen bindings so we can overwrite from Python side # needs to be after the above ATen bindings so we can overwrite from Python side
from torch import functional as functional
from .functional import * # noqa: F403 from .functional import * # noqa: F403
@ -1769,10 +1761,8 @@ del _LegacyStorage
def _assert(condition, message): def _assert(condition, message):
r"""A wrapper around Python's assert which is symbolically traceable. r"""A wrapper around Python's assert which is symbolically traceable.
""" """
from .overrides import has_torch_function, handle_torch_function if type(condition) is not torch.Tensor and overrides.has_torch_function((condition,)):
return overrides.handle_torch_function(_assert, (condition,), condition, message)
if type(condition) is not torch.Tensor and has_torch_function((condition,)):
return handle_torch_function(_assert, (condition,), condition, message)
assert condition, message assert condition, message
################################################################################ ################################################################################
@ -1801,7 +1791,6 @@ from torch import nested as nested
from torch import nn as nn from torch import nn as nn
from torch.signal import windows as windows from torch.signal import windows as windows
from torch import optim as optim from torch import optim as optim
import torch.optim._multi_tensor
from torch import multiprocessing as multiprocessing from torch import multiprocessing as multiprocessing
from torch import sparse as sparse from torch import sparse as sparse
from torch import special as special from torch import special as special
@ -1809,7 +1798,6 @@ import torch.utils.backcompat
from torch import jit as jit from torch import jit as jit
from torch import linalg as linalg from torch import linalg as linalg
from torch import hub as hub from torch import hub as hub
from torch import random as random
from torch import distributions as distributions from torch import distributions as distributions
from torch import testing as testing from torch import testing as testing
from torch import backends as backends from torch import backends as backends
@ -1817,6 +1805,8 @@ import torch.utils.data
from torch import __config__ as __config__ from torch import __config__ as __config__
from torch import __future__ as __future__ from torch import __future__ as __future__
from torch import profiler as profiler from torch import profiler as profiler
from torch import overrides as overrides
from torch import types as types
# Quantized, sparse, AO, etc. should be last to get imported, as nothing # Quantized, sparse, AO, etc. should be last to get imported, as nothing
# is expected to depend on them. # is expected to depend on them.
@ -1827,7 +1817,7 @@ import torch.nn.quantized
import torch.nn.qat import torch.nn.qat
import torch.nn.intrinsic import torch.nn.intrinsic
_C._init_names(list(torch._storage_classes)) _C._init_names(list(_storage_classes))
# attach docstrings to torch and tensor functions # attach docstrings to torch and tensor functions
from . import _torch_docs, _tensor_docs, _storage_docs, _size_docs from . import _torch_docs, _tensor_docs, _storage_docs, _size_docs
@ -1854,7 +1844,7 @@ from torch import quasirandom as quasirandom
# If you are seeing this, it means that this call site was not checked if # If you are seeing this, it means that this call site was not checked if
# the memory format could be preserved, and it was switched to old default # the memory format could be preserved, and it was switched to old default
# behaviour of contiguous # behaviour of contiguous
legacy_contiguous_format = contiguous_format legacy_contiguous_format = contiguous_format # defined by _C._initExtension()
# Register fork handler to initialize OpenMP in child processes (see gh-28389) # Register fork handler to initialize OpenMP in child processes (see gh-28389)
from torch.multiprocessing._atfork import register_after_fork from torch.multiprocessing._atfork import register_after_fork
@ -1876,7 +1866,7 @@ from torch.utils.dlpack import from_dlpack, to_dlpack
# Import experimental masked operations support. See # Import experimental masked operations support. See
# [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more # [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more
# information. # information.
from . import masked from torch import masked as masked
# Import removed ops with error message about removal # Import removed ops with error message about removal
from ._linalg_utils import ( # type: ignore[misc] from ._linalg_utils import ( # type: ignore[misc]

View File

@ -6,21 +6,22 @@ enough, so that more sophisticated ones can also be easily integrated in the
future. future.
""" """
from . import lr_scheduler, swa_utils from torch.optim import lr_scheduler, swa_utils
from .adadelta import Adadelta from torch.optim.adadelta import Adadelta
from .adagrad import Adagrad from torch.optim.adagrad import Adagrad
from .adam import Adam from torch.optim.adam import Adam
from .adamax import Adamax from torch.optim.adamax import Adamax
from .adamw import AdamW from torch.optim.adamw import AdamW
from .asgd import ASGD from torch.optim.asgd import ASGD
from .lbfgs import LBFGS from torch.optim.lbfgs import LBFGS
from .nadam import NAdam from torch.optim.nadam import NAdam
from .optimizer import Optimizer from torch.optim.optimizer import Optimizer
from .radam import RAdam from torch.optim.radam import RAdam
from .rmsprop import RMSprop from torch.optim.rmsprop import RMSprop
from .rprop import Rprop from torch.optim.rprop import Rprop
from .sgd import SGD from torch.optim.sgd import SGD
from .sparse_adam import SparseAdam from torch.optim.sparse_adam import SparseAdam
del adadelta # type: ignore[name-defined] # noqa: F821 del adadelta # type: ignore[name-defined] # noqa: F821
del adagrad # type: ignore[name-defined] # noqa: F821 del adagrad # type: ignore[name-defined] # noqa: F821
@ -36,3 +37,6 @@ del rmsprop # type: ignore[name-defined] # noqa: F821
del optimizer # type: ignore[name-defined] # noqa: F821 del optimizer # type: ignore[name-defined] # noqa: F821
del nadam # type: ignore[name-defined] # noqa: F821 del nadam # type: ignore[name-defined] # noqa: F821
del lbfgs # type: ignore[name-defined] # noqa: F821 del lbfgs # type: ignore[name-defined] # noqa: F821
import torch.optim._multi_tensor

View File

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import torch import torch
from torch import Tensor
from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union
@ -213,7 +212,7 @@ class WeightedRandomSampler(Sampler[int]):
[0, 1, 4, 3, 2] [0, 1, 4, 3, 2]
""" """
weights: Tensor weights: torch.Tensor
num_samples: int num_samples: int
replacement: bool replacement: bool