[BE][Easy] export explicitly imported public submodules (#127703)

Add top-level submodules `torch.{storage,serialization,functional,amp,overrides,types}`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127703
Approved by: https://github.com/ezyang
This commit is contained in:
Xuehai Pan 2024-06-10 19:16:54 +00:00 committed by PyTorch MergeBot
parent 62311257ad
commit dcc0093dba
3 changed files with 57 additions and 64 deletions

View File

@ -1,6 +1,4 @@
# mypy: allow-untyped-defs
r"""
"""
The torch package contains data structures for multi-dimensional
tensors and defines mathematical operations over these tensors.
Additionally, it provides many utilities for efficient serialization of
@ -10,6 +8,8 @@ It has a CUDA counterpart, that enables you to run your tensor computations
on an NVIDIA GPU with compute capability >= 3.0.
"""
# mypy: allow-untyped-defs
import math
import os
import sys
@ -289,10 +289,6 @@ else:
_load_global_deps()
from torch._C import * # noqa: F403
# Appease the type checker; ordinarily this binding is inserted by the
# torch._C module initialization code in C
if TYPE_CHECKING:
from . import _C as _C # noqa: TCH004
class SymInt:
"""
@ -614,10 +610,9 @@ def sym_not(a):
a (SymBool or bool): Object to negate
"""
import sympy
from .overrides import has_torch_function_unary, handle_torch_function
if has_torch_function_unary(a):
return handle_torch_function(sym_not, (a,), a)
if overrides.has_torch_function_unary(a):
return overrides.handle_torch_function(sym_not, (a,), a)
if hasattr(a, '__sym_not__'):
return a.__sym_not__()
if isinstance(a, sympy.Basic):
@ -630,10 +625,8 @@ def sym_float(a):
Args:
a (SymInt, SymFloat, or object): Object to cast
"""
from .overrides import has_torch_function_unary, handle_torch_function
if has_torch_function_unary(a):
return handle_torch_function(sym_float, (a,), a)
if overrides.has_torch_function_unary(a):
return overrides.handle_torch_function(sym_float, (a,), a)
if isinstance(a, SymFloat):
return a
elif hasattr(a, '__sym_float__'):
@ -647,10 +640,8 @@ def sym_int(a):
Args:
a (SymInt, SymFloat, or object): Object to cast
"""
from .overrides import has_torch_function_unary, handle_torch_function
if has_torch_function_unary(a):
return handle_torch_function(sym_int, (a,), a)
if overrides.has_torch_function_unary(a):
return overrides.handle_torch_function(sym_int, (a,), a)
if isinstance(a, SymInt):
return a
elif isinstance(a, SymFloat):
@ -664,10 +655,8 @@ def sym_max(a, b):
promotes to float if any argument is float (unlike builtins.max, which
will faithfully preserve the type of the input argument).
"""
from .overrides import has_torch_function, handle_torch_function
if has_torch_function((a, b)):
return handle_torch_function(sym_max, (a, b), a, b)
if overrides.has_torch_function((a, b)):
return overrides.handle_torch_function(sym_max, (a, b), a, b)
if isinstance(a, (SymInt, SymFloat)):
return a.__sym_max__(b)
elif isinstance(b, (SymInt, SymFloat)):
@ -683,11 +672,9 @@ def sym_max(a, b):
return builtins.max(a, b)
def sym_min(a, b):
""" SymInt-aware utility for min()."""
from .overrides import has_torch_function, handle_torch_function
if has_torch_function((a, b)):
return handle_torch_function(sym_min, (a, b), a, b)
"""SymInt-aware utility for min()."""
if overrides.has_torch_function((a, b)):
return overrides.handle_torch_function(sym_min, (a, b), a, b)
if isinstance(a, (SymInt, SymFloat)):
return a.__sym_min__(b)
elif isinstance(b, (SymInt, SymFloat)):
@ -702,10 +689,8 @@ def sym_min(a, b):
# Drop in replacement for math.sqrt, math.sin, math.cos etc
def _get_sym_math_fn(name):
def fn(a):
from .overrides import has_torch_function_unary, handle_torch_function
if has_torch_function_unary(a):
return handle_torch_function(fn, (a,), a)
if overrides.has_torch_function_unary(a):
return overrides.handle_torch_function(fn, (a,), a)
if hasattr(a, f"__sym_{name}__"):
return getattr(a, f"__sym_{name}__")()
return getattr(math, name)(a)
@ -727,10 +712,8 @@ __all__.append("sym_sqrt")
def sym_ite(b, t, f):
from .overrides import has_torch_function, handle_torch_function
if has_torch_function((b, t, f)):
return handle_torch_function(sym_ite, (b, t, f), b, t, f)
if overrides.has_torch_function((b, t, f)):
return overrides.handle_torch_function(sym_ite, (b, t, f), b, t, f)
assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f)
if isinstance(b, SymBool):
return b.__sym_ite__(t, f)
@ -760,16 +743,20 @@ except ImportError:
''').strip()) from None
raise # If __file__ is not None the cause is unknown, so just re-raise.
# The torch._C submodule is already loaded via `from torch._C import *` above
# Make an explicit reference to the _C submodule to appease linters
from torch import _C as _C
__name, __obj = '', None
for __name in dir(_C):
if __name[0] != '_' and not __name.endswith('Base'):
__all__.append(__name)
__obj = getattr(_C, __name)
if callable(__obj) or inspect.isclass(__obj):
if __obj.__module__ != __name__:
if __obj.__module__ != __name__: # "torch"
# TODO: fix their module from C++ side
if __name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']:
__obj.__module__ = __name__
__obj.__module__ = __name__ # "torch"
elif __name == 'TensorBase':
# issue 109438 / pr 109940. Prevent TensorBase from being copied into torch.
delattr(sys.modules[__name__], __name)
@ -1478,6 +1465,7 @@ __all__.extend(['e', 'pi', 'nan', 'inf', 'newaxis'])
################################################################################
from ._tensor import Tensor
from torch import storage as storage
from .storage import _StorageBase, TypedStorage, _LegacyStorage, UntypedStorage, _warn_typed_storage_removal
# NOTE: New <type>Storage classes should never be added. When adding a new
@ -1665,7 +1653,9 @@ _storage_classes = {
_tensor_classes: Set[Type] = set()
# If you edit these imports, please update torch/__init__.py.in as well
from torch import random as random
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
from torch import serialization as serialization
from .serialization import save, load
from ._tensor_str import set_printoptions
@ -1682,6 +1672,7 @@ def _manager_path():
raise RuntimeError("Unable to find torch_shm_manager at " + path)
return path.encode('utf-8')
from torch import amp as amp
from torch.amp import autocast, GradScaler
# Initializing the extension shadows the built-in python float / int classes;
@ -1717,7 +1708,7 @@ for __name in dir(_C._VariableFunctions):
if __name.startswith('__') or __name in PRIVATE_OPS:
continue
__obj = getattr(_C._VariableFunctions, __name)
__obj.__module__ = __name__
__obj.__module__ = __name__ # "torch"
# Hide some APIs that should not be public
if __name == "segment_reduce":
# TODO: Once the undocumented FC window is passed, remove the line bellow
@ -1751,6 +1742,7 @@ from ._compile import _disable_dynamo
################################################################################
# needs to be after the above ATen bindings so we can overwrite from Python side
from torch import functional as functional
from .functional import * # noqa: F403
@ -1769,10 +1761,8 @@ del _LegacyStorage
def _assert(condition, message):
r"""A wrapper around Python's assert which is symbolically traceable.
"""
from .overrides import has_torch_function, handle_torch_function
if type(condition) is not torch.Tensor and has_torch_function((condition,)):
return handle_torch_function(_assert, (condition,), condition, message)
if type(condition) is not torch.Tensor and overrides.has_torch_function((condition,)):
return overrides.handle_torch_function(_assert, (condition,), condition, message)
assert condition, message
################################################################################
@ -1801,7 +1791,6 @@ from torch import nested as nested
from torch import nn as nn
from torch.signal import windows as windows
from torch import optim as optim
import torch.optim._multi_tensor
from torch import multiprocessing as multiprocessing
from torch import sparse as sparse
from torch import special as special
@ -1809,7 +1798,6 @@ import torch.utils.backcompat
from torch import jit as jit
from torch import linalg as linalg
from torch import hub as hub
from torch import random as random
from torch import distributions as distributions
from torch import testing as testing
from torch import backends as backends
@ -1817,6 +1805,8 @@ import torch.utils.data
from torch import __config__ as __config__
from torch import __future__ as __future__
from torch import profiler as profiler
from torch import overrides as overrides
from torch import types as types
# Quantized, sparse, AO, etc. should be last to get imported, as nothing
# is expected to depend on them.
@ -1827,7 +1817,7 @@ import torch.nn.quantized
import torch.nn.qat
import torch.nn.intrinsic
_C._init_names(list(torch._storage_classes))
_C._init_names(list(_storage_classes))
# attach docstrings to torch and tensor functions
from . import _torch_docs, _tensor_docs, _storage_docs, _size_docs
@ -1854,7 +1844,7 @@ from torch import quasirandom as quasirandom
# If you are seeing this, it means that this call site was not checked if
# the memory format could be preserved, and it was switched to old default
# behaviour of contiguous
legacy_contiguous_format = contiguous_format
legacy_contiguous_format = contiguous_format # defined by _C._initExtension()
# Register fork handler to initialize OpenMP in child processes (see gh-28389)
from torch.multiprocessing._atfork import register_after_fork
@ -1876,7 +1866,7 @@ from torch.utils.dlpack import from_dlpack, to_dlpack
# Import experimental masked operations support. See
# [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more
# information.
from . import masked
from torch import masked as masked
# Import removed ops with error message about removal
from ._linalg_utils import ( # type: ignore[misc]

View File

@ -6,21 +6,22 @@ enough, so that more sophisticated ones can also be easily integrated in the
future.
"""
from . import lr_scheduler, swa_utils
from .adadelta import Adadelta
from .adagrad import Adagrad
from .adam import Adam
from .adamax import Adamax
from .adamw import AdamW
from .asgd import ASGD
from .lbfgs import LBFGS
from .nadam import NAdam
from .optimizer import Optimizer
from .radam import RAdam
from .rmsprop import RMSprop
from .rprop import Rprop
from .sgd import SGD
from .sparse_adam import SparseAdam
from torch.optim import lr_scheduler, swa_utils
from torch.optim.adadelta import Adadelta
from torch.optim.adagrad import Adagrad
from torch.optim.adam import Adam
from torch.optim.adamax import Adamax
from torch.optim.adamw import AdamW
from torch.optim.asgd import ASGD
from torch.optim.lbfgs import LBFGS
from torch.optim.nadam import NAdam
from torch.optim.optimizer import Optimizer
from torch.optim.radam import RAdam
from torch.optim.rmsprop import RMSprop
from torch.optim.rprop import Rprop
from torch.optim.sgd import SGD
from torch.optim.sparse_adam import SparseAdam
del adadelta # type: ignore[name-defined] # noqa: F821
del adagrad # type: ignore[name-defined] # noqa: F821
@ -36,3 +37,6 @@ del rmsprop # type: ignore[name-defined] # noqa: F821
del optimizer # type: ignore[name-defined] # noqa: F821
del nadam # type: ignore[name-defined] # noqa: F821
del lbfgs # type: ignore[name-defined] # noqa: F821
import torch.optim._multi_tensor

View File

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import torch
from torch import Tensor
from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union
@ -213,7 +212,7 @@ class WeightedRandomSampler(Sampler[int]):
[0, 1, 4, 3, 2]
"""
weights: Tensor
weights: torch.Tensor
num_samples: int
replacement: bool